Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prec param new #1813

Open
wants to merge 2 commits into
base: jakpiase/ck_tile_gemm_api
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions example/ck_tile/03_gemm/gemm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@

#include "gemm_basic.hpp"

template <typename ALayout, typename BLayout, typename CLayout>
template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeConfig>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using Types = GemmBasicTypeConfig<DataTypeConfig>;

// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;

// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
Expand Down Expand Up @@ -100,30 +108,49 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
return ave_time;
}

float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
// Dispatches a GEMM call to the `gemm_` instantiation that matches the runtime
// layout flags stored in `t`, forwarding the compile-time `DataType` tag.
//
// DataType : tag type forwarded to `gemm_` as its fourth template argument.
// t        : runtime traits; only the is_{a,b,c}_rowmajor flags are read here.
// args, s  : passed through unchanged to `gemm_`.
//
// Returns whatever the selected `gemm_` instantiation returns.
// Throws std::runtime_error when the layout combination is not one of the four
// handled below (every handled combination requires a row-major C).
template <typename DataType>
float gemm_type_(const gemm_traits& t,
                 const ck_tile::GemmHostArgs& args,
                 const ck_tile::stream_config& s)
{
    // `DataType` cannot be deduced from the function arguments, so it must be
    // spelled out explicitly in every `gemm_` call.
    if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
    {
        return gemm_<Row, Row, Row, DataType>(args, s);
    }
    else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
    {
        return gemm_<Row, Col, Row, DataType>(args, s);
    }
    else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
    {
        return gemm_<Col, Row, Row, DataType>(args, s);
    }
    else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
    {
        return gemm_<Col, Col, Row, DataType>(args, s);
    }
    else
    {
        throw std::runtime_error("Wrong! Layouts not supported!\n");
    }
}

float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if(t.data_type == "fp16")
{
return gemm_type_<GemmFp16>(t, args, s);
}
else if(t.data_type == "bf16")
{
return gemm_type_<GemmBf16>(t, args, s);
}
else
{
throw std::runtime_error("Wrong! Data type not supported!\n");
}
Comment on lines +140 to +151
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also have support for fp8. Take a look here:

using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
2,
swizzle_factor>>;
We have fp8 MFMA instructions added.

}

auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
Expand Down
31 changes: 23 additions & 8 deletions example/ck_tile/03_gemm/gemm_basic.hpp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file should be renamed to gemm.hpp - this is our GEMM host API definition.

Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"

// Empty tag type that selects the fp16 specialization of GemmBasicTypeConfig.
struct GemmFp16 {};

// Empty tag type that selects the bf16 specialization of GemmBasicTypeConfig.
struct GemmBf16 {};
Comment on lines +13 to +19
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about just using an enum class?


template <typename DataType>
struct GemmBasicTypeConfig;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this structure to GemmTypeConfig.

Comment on lines 21 to 22
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add documentation. Describe what this class defines (the supported data type configurations).


template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
struct GemmBasicTypeConfig<GemmFp16>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
Expand All @@ -23,6 +31,15 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM.
};

// Type configuration selected by the GemmBf16 tag: the A, B and C element types
// are all ck_tile::bf16_t, while AccDataType is float.
template <>
struct GemmBasicTypeConfig<GemmBf16>
{
    // Operand and result element types.
    using ADataType = ck_tile::bf16_t;
    using BDataType = ck_tile::bf16_t;
    using CDataType = ck_tile::bf16_t;

    // Accumulator element type.
    using AccDataType = float;
};

template <typename T>
struct DataTypeTraits;

Expand All @@ -44,13 +61,11 @@ struct DataTypeTraits<ck_tile::half_t>
static constexpr const char* name = "fp16";
};

using Types = GemmBasicTypeConfig<ck_tile::half_t>;

// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
// Maps ck_tile::bf16_t to its short textual name, "bf16".
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
    static constexpr const char* name = "bf16";
};

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
Expand Down
84 changes: 62 additions & 22 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is awkward. You repeat GemmBasicTypeConfig in every function... Why not just pass the data type and the layouts as strings? We should manage this inside our host API and not repeat the same logic at every step.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

template <typename ALayout, typename BLayout, typename CLayout>
template <typename ALayout, typename BLayout, typename CLayout, typename DataType>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
Expand All @@ -16,6 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
using Types = GemmBasicTypeConfig<DataType>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using CDataType = typename Types::CDataType;

ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
Expand Down Expand Up @@ -50,7 +55,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time;
}

template <typename ALayout, typename BLayout, typename CLayout>
template <typename ALayout, typename BLayout, typename CLayout, typename DataType>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
Expand All @@ -61,6 +66,12 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;

using Types = GemmBasicTypeConfig<DataType>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;

ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
Expand Down Expand Up @@ -129,18 +140,18 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
invoke_gemm<ALayout, BLayout, CLayout, DataType>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);

c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
Expand Down Expand Up @@ -209,33 +220,62 @@ int run_gemm_example_with_layouts(int argc,
return pass;
}

int run_gemm_example(int argc, char* argv[])
// Selects the layout-specific example runner for a fixed `DataType` tag.
//
// DataType   : tag type forwarded to `run_gemm_example_with_layouts`.
// argc, argv : forwarded unchanged to `run_gemm_example_with_layouts`.
// a_layout   : layout code for A; "R" and "C" are the recognised values.
// b_layout   : layout code for B; "R" and "C" are the recognised values.
//
// Returns the value produced by `run_gemm_example_with_layouts`.
// Throws std::runtime_error for any other layout combination.
//
// The layouts arrive as parameters, so they are not parsed from the command
// line again here. C is always instantiated as Row.
template <typename DataType>
int run_gemm_example_with_datatype(int argc,
                                   char* argv[],
                                   const std::string& a_layout,
                                   const std::string& b_layout)
{
    // `DataType` is not deducible from the call arguments, hence the explicit
    // template argument list on every call.
    if(a_layout == "R" && b_layout == "R")
    {
        return run_gemm_example_with_layouts<Row, Row, Row, DataType>(
            argc, argv, Row{}, Row{}, Row{});
    }
    else if(a_layout == "R" && b_layout == "C")
    {
        return run_gemm_example_with_layouts<Row, Col, Row, DataType>(
            argc, argv, Row{}, Col{}, Row{});
    }
    else if(a_layout == "C" && b_layout == "C")
    {
        return run_gemm_example_with_layouts<Col, Col, Row, DataType>(
            argc, argv, Col{}, Col{}, Row{});
    }
    else if(a_layout == "C" && b_layout == "R")
    {
        return run_gemm_example_with_layouts<Col, Row, Row, DataType>(
            argc, argv, Col{}, Row{}, Row{});
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
}

// Entry point of the example: parses the command line, then dispatches on the
// "prec" argument to the matching data-type-specific runner.
//
// Returns -1 when `create_args` reports failure, otherwise the value returned
// by `run_gemm_example_with_datatype`.
// Throws std::runtime_error when "prec" is neither "fp16" nor "bf16".
int run_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        return -1;
    }

    const std::string a_layout = arg_parser.get_str("a_layout");
    const std::string b_layout = arg_parser.get_str("b_layout");
    const std::string prec     = arg_parser.get_str("prec");

    if(prec == "fp16")
    {
        return run_gemm_example_with_datatype<GemmFp16>(argc, argv, a_layout, b_layout);
    }

    if(prec == "bf16")
    {
        return run_gemm_example_with_datatype<GemmBf16>(argc, argv, a_layout, b_layout);
    }

    throw std::runtime_error("Unsupported data type!");
}