change kleidiai interface
metascroy committed Feb 7, 2025
1 parent 5c45936 commit 4ab3e84
Showing 2 changed files with 161 additions and 63 deletions.
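In short: the packing entry points no longer take a `Ukernel` struct (querying `get_mr()`/`get_kr()`/`get_sr()` on it) and instead take explicit `mr`/`nr`/`kr`/`sr` parameters, with macro-generated wrappers that hard-code common parameter tuples in their names. A before/after sketch of a caller; the values of `m`, `k`, and `group_size` here are hypothetical, not from this commit:

// Before this commit: packing parameters were pulled from the ukernel struct.
auto uk = neon_dotprod_1x8x32::get_ukernel();
size_t sz_old = activation_data_size(uk, m, k);

// After: pass mr/kr/sr directly, or call a name-mangled wrapper.
size_t sz_new = activation_data_size(/*mr=*/1, /*kr=*/16, /*sr=*/2, m, k);
size_t sz_wrapped = activation_data_size_mr1_kr16_sr2(m, k, group_size);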
@@ -15,6 +15,13 @@

#include <kai/kai_common.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>

#ifdef TORCHAO_ENABLE_ARM_I8MM
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>
#endif // TORCHAO_ENABLE_ARM_I8MM

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>

@@ -43,14 +50,16 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;

size_t activation_data_size(const Ukernel ukernel, int m, int k) {
size_t activation_data_size(int mr, int kr, int sr, int m, int k) {
auto lhs_packing = get_lhs_packing();
return lhs_packing.get_lhs_packed_size(
m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
m, k, mr, kr, sr);
}

void prepare_activation_data(
const Ukernel ukernel,
int mr,
int kr,
int sr,
void* activation_data,
int m,
int k,
@@ -60,29 +69,31 @@ void prepare_activation_data(
lhs_pack.run_lhs_pack(
m,
k,
ukernel.get_mr(),
ukernel.get_kr(),
ukernel.get_sr(),
mr,
kr,
sr,
/*m_index_start=*/0,
activations,
/*lhs_stride=*/k * sizeof(float),
activation_data);
}

size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
size_t weight_data_size(int nr, int kr, int sr, int n, int k, int group_size) {
auto rhs_pack = get_rhs_packing();
return rhs_pack.get_rhs_packed_size(
n,
k,
ukernel.get_nr(),
ukernel.get_kr(),
ukernel.get_sr(),
nr,
kr,
sr,
group_size,
kai_datatype::kai_dt_bf16);
}

void prepare_weight_data(
const Ukernel ukernel,
int nr,
int kr,
int sr,
void* weight_data,
int n,
int k,
@@ -134,9 +145,9 @@ void prepare_weight_data(
/*groups=*/1,
n,
k,
ukernel.get_nr(),
ukernel.get_kr(),
ukernel.get_sr(),
nr,
kr,
sr,
group_size,
/*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
/*rhs_stride=*/roundup(k, 2) / 2,
@@ -148,5 +159,99 @@ void prepare_weight_data(
/*qparams=*/&qparams);
}


size_t get_preferred_alignement() {
return 16;
}


#define DEFINE_WEIGHT_DATA_FNS(nr, kr, sr) \
size_t weight_data_size_nr##nr##_kr##kr##_sr##sr(int n, int k, int group_size) { \
return weight_data_size(nr, kr, sr, n, k, group_size); \
} \
void prepare_weight_data_nr##nr##_kr##kr##_sr##sr( \
void* weight_data, \
int n, \
int k, \
int group_size, \
const int8_t* weight_qvals, \
const float* weight_scales, \
const int8_t* weight_zeros, \
const float* bias) { \
prepare_weight_data(nr, kr, sr, weight_data, n, k, group_size, weight_qvals, weight_scales, weight_zeros, bias); \
}
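For reference, the single instantiation of this macro below, DEFINE_WEIGHT_DATA_FNS(8, 16, 2), expands to roughly the following (a sketch with the line continuations removed):

size_t weight_data_size_nr8_kr16_sr2(int n, int k, int group_size) {
  return weight_data_size(8, 16, 2, n, k, group_size);
}
void prepare_weight_data_nr8_kr16_sr2(
    void* weight_data,
    int n,
    int k,
    int group_size,
    const int8_t* weight_qvals,
    const float* weight_scales,
    const int8_t* weight_zeros,
    const float* bias) {
  prepare_weight_data(
      8, 16, 2, weight_data, n, k, group_size,
      weight_qvals, weight_scales, weight_zeros, bias);
}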

#define DEFINE_ACTIVATION_DATA_FNS(mr, kr, sr) \
size_t activation_data_size_mr##mr##_kr##kr##_sr##sr(int m, int k, int group_size) { \
(void)group_size; \
return activation_data_size(mr, kr, sr, m, k); \
} \
void prepare_activation_data_mr##mr##_kr##kr##_sr##sr(void* activation_data, int m, int k, int group_size, const float* activations) { \
(void)group_size; \
prepare_activation_data(mr, kr, sr, activation_data, m, k, activations); \
}
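Likewise, DEFINE_ACTIVATION_DATA_FNS(1, 16, 2) below generates activation_data_size_mr1_kr16_sr2 and prepare_activation_data_mr1_kr16_sr2. The group_size parameter is accepted and discarded, presumably so the activation-side wrappers share a common function-pointer signature with the grouped weight-side ones; a sketch of the size function's expansion:

size_t activation_data_size_mr1_kr16_sr2(int m, int k, int group_size) {
  (void)group_size;  // unused: LHS packing does not depend on group size
  return activation_data_size(1, 16, 2, m, k);
}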

// TODO: first and suffix need to be better, e.g., parametrized by mr, nr, etc
// But I don't quite follow the naming convention for KleidiAI
#define DEFINE_KERNEL_FNS(first, suffix) \
namespace impl_##suffix { \
const Ukernel get_ukernel() { \
return Ukernel{ \
.get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.run_matmul = kai_run_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix \
}; \
} \
void kernel( \
float32_t* output, \
int output_m_stride, \
int m, \
int n, \
int k, \
int group_size, \
const void* weight_data, \
const void* activation_data, \
float clamp_min, \
float clamp_max) { \
get_ukernel().run_matmul( \
m, \
n, \
k, \
group_size, \
activation_data, \
weight_data, \
output, \
/*dst_stride_row=*/ output_m_stride * sizeof(float), \
/*dst_stride_col=*/ sizeof(float), \
/*clamp_min=*/std::numeric_limits<float>::lowest(), \
/*clamp_max=*/std::numeric_limits<float>::max() \
); \
} \
}



DEFINE_WEIGHT_DATA_FNS(/*nr*/8, /*kr*/16, /*sr*/2)
DEFINE_ACTIVATION_DATA_FNS(/*mr*/1, /*kr*/16, /*sr*/2)
DEFINE_KERNEL_FNS(1x8, 8x8_1x8x32_neon_dotprod)
DEFINE_KERNEL_FNS(1x8, 4x8_1x4x32_neon_dotprod)

#ifdef TORCHAO_ENABLE_ARM_I8MM
DEFINE_KERNEL_FNS(4x8, 4x8_8x4x32_neon_i8mm)
DEFINE_KERNEL_FNS(4x8, 8x8_4x8x32_neon_i8mm)
#endif // TORCHAO_ENABLE_ARM_I8MM

#undef DEFINE_WEIGHT_DATA_FNS
#undef DEFINE_ACTIVATION_DATA_FNS
#undef DEFINE_KERNEL_FNS
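Taken together, a rough caller-side sketch of the generated interface (buffer handling and variable names are illustrative, not from this commit; std::vector stands in for whatever aligned allocation the caller uses):

namespace op = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p;

// 1. Pack the n x k quantized weights once.
std::vector<char> weight_data(op::weight_data_size_nr8_kr16_sr2(n, k, group_size));
op::prepare_weight_data_nr8_kr16_sr2(
    weight_data.data(), n, k, group_size,
    weight_qvals, weight_scales, weight_zeros, bias);

// 2. Pack the m x k float activations per call.
std::vector<char> activation_data(op::activation_data_size_mr1_kr16_sr2(m, k, group_size));
op::prepare_activation_data_mr1_kr16_sr2(
    activation_data.data(), m, k, group_size, activations);

// 3. Run the matmul. Note that the generated kernel() ignores its
// clamp_min/clamp_max arguments and passes the full float range through.
op::impl_8x8_1x8x32_neon_dotprod::kernel(
    output, /*output_m_stride=*/n, m, n, k, group_size,
    weight_data.data(), activation_data.data(),
    /*clamp_min=*/std::numeric_limits<float>::lowest(),
    /*clamp_max=*/std::numeric_limits<float>::max());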

} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
@@ -17,12 +17,7 @@
#include <unordered_map>

#if defined(TORCHAO_ENABLE_KLEIDI)
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
#if defined (TORCHAO_ENABLE_ARM_I8MM)
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>
#endif // TORCHAO_ENABLE_ARM_I8MM
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>
#endif // TORCHAO_ENABLE_KLEIDI

namespace torchao::ops::linear_8bit_act_xbit_weight {
@@ -208,104 +203,102 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
"Kernel expects has_bias=true, but packed_weights have has_bias=" + std::to_string(kleidi_ai_format.has_bias)
);
}
namespace op = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p;

if (nr == 8 && kr == 16 && sr == 2) {
#if defined (TORCHAO_ENABLE_ARM_I8MM)
if (cpuinfo_has_arm_i8mm()) {
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32;
auto uk = kernel::get_ukernel();
assert (nr == uk.get_nr());
assert (kr == uk.get_kr());
assert (sr == uk.get_sr());
table.register_ukernel_config(
format,
uarch,
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
/*preferred_alignment*/kernel::get_preferred_alignement(),
/*weight_packing*/
{
/*nr*/static_cast<int>(uk.get_n_step()),
/*weight_data_size_fn*/&kernel::weight_data_size,
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
},
/*kernels*/
{{
auto uk = op::impl_8x8_4x8x32_neon_i8mm::get_ukernel();
assert (nr == uk.get_nr());
assert (kr == uk.get_kr());
assert (sr == uk.get_sr());
table.register_ukernel_config(
format,
uarch,
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
/*preferred_alignment*/op::get_preferred_alignement(),
/*weight_packing*/
{
/*nr*/static_cast<int>(uk.get_n_step()),
/*weight_data_size_fn*/&op::weight_data_size_nr8_kr16_sr2,
/*prepare_weight_data_fn*/&op::prepare_weight_data_nr8_kr16_sr2
},
/*kernels*/
{{
{
/*mr*/static_cast<int>(uk.get_m_step()),
/*activation_data_size_fn*/&kernel::activation_data_size,
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
/*kernel*/&kernel::kernel
/*activation_data_size_fn*/&op::activation_data_size_mr1_kr16_sr2,
/*prepare_activation_data_fn*/&op::prepare_activation_data_mr1_kr16_sr2,
/*kernel*/&op::impl_8x8_4x8x32_neon_i8mm::kernel
}
}}
}
);
return;
}}
}
);
return;
}
#endif // TORCHAO_ENABLE_ARM_I8MM

if (cpuinfo_has_arm_neon_dot()) {
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
auto uk = kernel::get_ukernel();
auto uk = op::impl_8x8_1x8x32_neon_dotprod::get_ukernel();
assert (nr == uk.get_nr());
assert (kr == uk.get_kr());
assert (sr == uk.get_sr());
table.register_ukernel_config(
format,
uarch,
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
/*preferred_alignment*/kernel::get_preferred_alignement(),
/*preferred_alignment*/op::get_preferred_alignement(),
/*weight_packing*/
{
/*nr*/static_cast<int>(uk.get_n_step()),
/*weight_data_size_fn*/&kernel::weight_data_size,
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
/*weight_data_size_fn*/&op::weight_data_size_nr8_kr16_sr2,
/*prepare_weight_data_fn*/&op::prepare_weight_data_nr8_kr16_sr2
},
/*kernels*/
{{
{
/*mr*/static_cast<int>(uk.get_m_step()),
/*activation_data_size_fn*/&kernel::activation_data_size,
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
/*kernel*/&kernel::kernel
/*activation_data_size_fn*/&op::activation_data_size_mr1_kr16_sr2,
/*prepare_activation_data_fn*/&op::prepare_activation_data_mr1_kr16_sr2,
/*kernel*/&op::impl_8x8_1x8x32_neon_dotprod::kernel
}
}}
}
);
return;
}
}

if (nr == 4 && kr == 16 && sr == 2) {
if (cpuinfo_has_arm_neon_dot()) {
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
auto uk = kernel::get_ukernel();
auto uk = op::impl_4x8_1x4x32_neon_dotprod::get_ukernel();
assert (nr == uk.get_nr());
assert (kr == uk.get_kr());
assert (sr == uk.get_sr());
table.register_ukernel_config(
format,
uarch,
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
/*preferred_alignment*/kernel::get_preferred_alignement(),
/*preferred_alignment*/op::get_preferred_alignement(),
/*weight_packing*/
{
/*nr*/static_cast<int>(uk.get_n_step()),
/*weight_data_size_fn*/&kernel::weight_data_size,
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
/*weight_data_size_fn*/&op::weight_data_size_nr8_kr16_sr2,
/*prepare_weight_data_fn*/&op::prepare_weight_data_nr8_kr16_sr2
},
/*kernels*/
{{
{
/*mr*/static_cast<int>(uk.get_m_step()),
/*activation_data_size_fn*/&kernel::activation_data_size,
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
/*kernel*/&kernel::kernel
/*activation_data_size_fn*/&op::activation_data_size_mr1_kr16_sr2,
/*prepare_activation_data_fn*/&op::prepare_activation_data_mr1_kr16_sr2,
/*kernel*/&op::impl_4x8_1x4x32_neon_dotprod::kernel
}
}}
}
);
return;
}
}
}
#endif // TORCHAO_ENABLE_KLEIDI
}
