draft ukernel selection logic #1652

Draft · wants to merge 6 commits into main · changes from 5 commits shown
2 changes: 1 addition & 1 deletion torchao/experimental/CMakeLists.txt
@@ -22,7 +22,7 @@ if(NOT TORCHAO_INCLUDE_DIRS)
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

-option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
+option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" ON)
Contributor Author:
TODO: nocommit

if(TORCHAO_BUILD_KLEIDIAI)
message(STATUS "Building with Arm KleidiAI library")
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
24 changes: 12 additions & 12 deletions torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h
@@ -37,16 +37,16 @@ void check_embedding_inputs(
int packed_embedding_dim = (embedding_dim * weight_nbit) / 8;
TORCHAO_CHECK(
packed_weight_qvals.size(0) ==
-(torchao::ops::PackedWeightsHeader::size() +
+(torchao::ops::PackedWeightsFormat::serialized_size() +
(num_embeddings * packed_embedding_dim)),
"packed_weight_qvals is not the correct size");

-// Check header
-auto header = torchao::ops::PackedWeightsHeader::read(
+// Check packed_weights_format
+auto packed_weights_format = torchao::ops::PackedWeightsFormat::deserialize(
packed_weight_qvals.const_data_ptr());
TORCHAO_CHECK(
-header ==
-torchao::ops::embedding_xbit::get_packed_weights_header_universal(
+packed_weights_format ==
+torchao::ops::embedding_xbit::get_packed_weights_format_universal(
weight_nbit,
/*min_value_chunk_size=*/32,
/*max_value_chunk_size=*/128),
@@ -151,7 +151,7 @@ Tensor embedding_out_cpu(
embedding_dim,
group_size,
packed_weight_qvals.const_data_ptr<int8_t>() +
-torchao::ops::PackedWeightsHeader::size(),
+torchao::ops::PackedWeightsFormat::serialized_size(),
weight_scales.const_data_ptr<float>(),
weight_zeros_ptr,
index);
@@ -222,23 +222,23 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");

auto out = torch::empty(
-torchao::ops::PackedWeightsHeader::size() +
+torchao::ops::PackedWeightsFormat::serialized_size() +
(num_embeddings * packed_embedding_dim))
.to(torch::kInt8);

-auto header =
-torchao::ops::embedding_xbit::get_packed_weights_header_universal(
+auto packed_weights_format =
+torchao::ops::embedding_xbit::get_packed_weights_format_universal(
weight_nbit,
/*min_value_chunk_size=*/32,
/*max_value_chunk_size=*/128);
-header.write(out.mutable_data_ptr());
+packed_weights_format.serialize(out.mutable_data_ptr());

torchao::parallel_1d(0, num_embeddings, [&](int64_t idx) {
#if defined(__aarch64__) || defined(__ARM_NEON)
torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals<
weight_nbit>(
out.mutable_data_ptr<int8_t>() +
-torchao::ops::PackedWeightsHeader::size(),
+torchao::ops::PackedWeightsFormat::serialized_size(),
embedding_dim,
weight_qvals.const_data_ptr<int8_t>(),
idx);
@@ -261,7 +261,7 @@ Tensor pack_embedding_meta(const Tensor& weight_qvals) {
embedding_dim % 8 == 0, "embedding_dim must be a multiple of 8 to pack");
int packed_embedding_dim = embedding_dim * weight_nbit / 8;
return torch::empty(
-torchao::ops::PackedWeightsHeader::size() +
+torchao::ops::PackedWeightsFormat::serialized_size() +
(num_embeddings * packed_embedding_dim))
.to("meta");
}
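All three call sites above offset into the packed buffer the same way, which implies a simple layout: the serialized format descriptor first, then the packed rows. A small sketch of the size computation, matching the arithmetic in this diff (the helper function and its name are hypothetical; it assumes the relevant torchao ops header is included):

// Packed buffer layout implied by the offsets above:
//   [ PackedWeightsFormat (serialized_size() bytes) | num_embeddings packed rows ]
// Each row packs embedding_dim values at weight_nbit bits apiece.
#include <cstdint>

inline int64_t packed_embedding_buffer_size(
    int64_t num_embeddings, int64_t embedding_dim, int64_t weight_nbit) {
  // Same formula as pack_embedding_cpu/pack_embedding_meta in this PR.
  const int64_t packed_embedding_dim = (embedding_dim * weight_nbit) / 8;
  return torchao::ops::PackedWeightsFormat::serialized_size() +
      num_embeddings * packed_embedding_dim;
}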
@@ -10,13 +10,13 @@

namespace torchao::ops::embedding_xbit {

-inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
+inline torchao::ops::PackedWeightsFormat get_packed_weights_format_universal(
int weight_nbit,
int min_value_chunk_size,
int max_value_chunk_size,
int version = 1) {
-return torchao::ops::PackedWeightsHeader(
-torchao::ops::PackedWeightsFormat::embedding_xbit_universal,
+return torchao::ops::PackedWeightsFormat(
+torchao::ops::PackedWeightsType::embedding_xbit_universal,
{version,
weight_nbit,
min_value_chunk_size,
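The constructor call above, together with the serialize/deserialize/serialized_size calls in the embedding ops, pins down most of the renamed type's interface: the old PackedWeightsHeader becomes PackedWeightsFormat, and the old PackedWeightsFormat enum becomes PackedWeightsType. A minimal sketch consistent with those call sites follows; the field names, kMaxParams capacity, enum value, and wire format are all assumptions, not part of this PR:

// Sketch of an interface consistent with this diff's call sites.
#include <array>
#include <cstdint>
#include <cstring>
#include <initializer_list>

namespace torchao::ops {

enum class PackedWeightsType : uint32_t {
  unknown = 0,
  embedding_xbit_universal = 1, // enum value is an assumption
};

class PackedWeightsFormat {
 public:
  static constexpr int kMaxParams = 14; // capacity chosen for this sketch

  PackedWeightsType type{PackedWeightsType::unknown};
  std::array<int, kMaxParams> params{};

  PackedWeightsFormat(PackedWeightsType t, std::initializer_list<int> p)
      : type(t) {
    int i = 0;
    for (int v : p) {
      params[i++] = v; // unused slots stay zero-initialized
    }
  }

  // Fixed serialized footprint, independent of how many params are set.
  static constexpr size_t serialized_size() {
    return sizeof(PackedWeightsType) + kMaxParams * sizeof(int);
  }

  void serialize(void* dst) const {
    auto* out = static_cast<char*>(dst);
    std::memcpy(out, &type, sizeof(type));
    std::memcpy(out + sizeof(type), params.data(), kMaxParams * sizeof(int));
  }

  static PackedWeightsFormat deserialize(const void* src) {
    PackedWeightsFormat f{PackedWeightsType::unknown, {}};
    const auto* in = static_cast<const char*>(src);
    std::memcpy(&f.type, in, sizeof(f.type));
    std::memcpy(f.params.data(), in + sizeof(f.type), kMaxParams * sizeof(int));
    return f;
  }

  bool operator==(const PackedWeightsFormat& other) const {
    return type == other.type && params == other.params;
  }
};

} // namespace torchao::ops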
@@ -8,13 +8,23 @@ cmake_minimum_required(VERSION 3.19)

include(${CMAKE_CURRENT_SOURCE_DIR}/../../Utils.cmake)

+add_compile_options(-Wno-unused-function -Wno-unused-variable) # For some reason the cpuinfo package has unused functions/variables
Contributor:
Fix it upstream?


+include(FetchContent)
+FetchContent_Declare(cpuinfo
+GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git
+GIT_TAG main) # need main for benchmark::benchmark
+FetchContent_MakeAvailable(
+cpuinfo)

find_package(Torch REQUIRED)
add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT
linear_8bit_act_xbit_weight.cpp
op_linear_8bit_act_xbit_weight_aten.cpp
)
target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_aten aten_openmp)
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64)
+target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE cpuinfo)
target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE USE_ATEN=1)
@@ -37,4 +47,5 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}")
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64)
+target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE cpuinfo)
endif()
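The PR title ("draft ukernel selection logic") together with the new cpuinfo dependency suggests the selection logic will branch on runtime CPU features. A sketch of what such a dispatch could look like; the UKernel enum, its variants, and the function name are hypothetical, while the cpuinfo calls are the library's real API:

#include <cpuinfo.h>

// Hypothetical ukernel variants; only the cpuinfo queries below are real API.
enum class UKernel { kFallback, kNeonDot, kI8mm };

inline UKernel select_ukernel() {
  // cpuinfo_initialize() must precede any feature query; it returns false on failure.
  if (!cpuinfo_initialize()) {
    return UKernel::kFallback;
  }
  // Prefer the widest instruction-set extension the CPU supports.
  if (cpuinfo_has_arm_i8mm()) {
    return UKernel::kI8mm;    // 8-bit integer matrix-multiply extension
  }
  if (cpuinfo_has_arm_neon_dot()) {
    return UKernel::kNeonDot; // SDOT/UDOT dot-product instructions
  }
  return UKernel::kFallback;
}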