-
Notifications
You must be signed in to change notification settings - Fork 146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prec param new #1813
base: jakpiase/ck_tile_gemm_api
Are you sure you want to change the base?
Prec param new #1813
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file should be renamed to |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,11 +10,19 @@ | |
#include "ck_tile/ops/gemm.hpp" | ||
#include "ck_tile/ops/epilogue.hpp" | ||
|
||
struct GemmFp16 | ||
{ | ||
}; | ||
|
||
struct GemmBf16 | ||
{ | ||
}; | ||
Comment on lines
+13
to
+19
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about just enumerator class ? |
||
|
||
template <typename DataType> | ||
struct GemmBasicTypeConfig; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please rename this structure to GemmTypeConfig.
Comment on lines
21
to
22
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add documentation. Describe what this class defines ( supported data type configurations). |
||
|
||
template <> | ||
struct GemmBasicTypeConfig<ck_tile::half_t> | ||
struct GemmBasicTypeConfig<GemmFp16> | ||
{ | ||
using ADataType = ck_tile::half_t; | ||
using BDataType = ck_tile::half_t; | ||
|
@@ -23,6 +31,15 @@ struct GemmBasicTypeConfig<ck_tile::half_t> | |
// ToDo: Add more bias config to support different categories of GEMM. | ||
}; | ||
|
||
template <> | ||
struct GemmBasicTypeConfig<GemmBf16> | ||
{ | ||
using ADataType = ck_tile::bf16_t; | ||
using BDataType = ck_tile::bf16_t; | ||
using AccDataType = float; | ||
using CDataType = ck_tile::bf16_t; | ||
}; | ||
|
||
template <typename T> | ||
struct DataTypeTraits; | ||
|
||
|
@@ -44,13 +61,11 @@ struct DataTypeTraits<ck_tile::half_t> | |
static constexpr const char* name = "fp16"; | ||
}; | ||
|
||
using Types = GemmBasicTypeConfig<ck_tile::half_t>; | ||
|
||
// Specific type aliases for easy access | ||
using ADataType = Types::ADataType; | ||
using BDataType = Types::BDataType; | ||
using AccDataType = Types::AccDataType; | ||
using CDataType = Types::CDataType; | ||
template <> | ||
struct DataTypeTraits<ck_tile::bf16_t> | ||
{ | ||
static constexpr const char* name = "bf16"; | ||
}; | ||
|
||
using Row = ck_tile::tensor_layout::gemm::RowMajor; | ||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is awkward. You repeat |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should have also support for fp8. Take a look in here:
composable_kernel/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
Lines 125 to 158 in 3de7bd6