Add CUTLASS-based row-wise scaled sparse FP8 kernel
Commit e1c9bd0, 1 parent 867a91f.
Showing 9 changed files with 985 additions and 1 deletion.
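The new op takes a row-wise quantized FP8 activation plus a compressed 2:4-sparse, row-wise quantized FP8 weight, and applies both sets of scales and an optional bias. Below is a minimal usage sketch; the op names and argument order are taken from the test in this commit, while the shapes, dtypes, and the way the 2:4 pattern is constructed are illustrative assumptions, not part of the commit:

    import torch
    from torchao.ops import (
        rowwise_scaled_linear_sparse_cutlass_f8f8,
        to_sparse_semi_structured_cutlass_sm9x_f8,
    )

    # Illustrative sizes: m=32 activation rows, n=64 output features, k=128.
    xq = torch.randn(32, 128, device="cuda").to(torch.float8_e5m2)
    xq_scales = torch.rand(32, device="cuda", dtype=torch.float16)  # one scale per row

    # Build a weight satisfying the 2:4 pattern by zeroing two of every
    # four elements along the reduction dimension (illustrative only; the
    # test below constructs one with rand_sparse_semi_structured instead).
    w = torch.randn(64, 128, device="cuda")
    w.view(64, -1, 4)[..., 2:] = 0
    wq = w.to(torch.float8_e4m3fn)
    wq_scales = torch.rand(64, device="cuda", dtype=torch.float16)
    bias = torch.rand(64, device="cuda", dtype=torch.float16)

    # Compress the sparse weight into packed values plus CUTLASS metadata,
    # then compute Y = (xq @ wq^T) * xq_scales * wq_scales + bias.
    wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
    y = rowwise_scaled_linear_sparse_cutlass_f8f8(
        xq, xq_scales, wq_sp, wq_sp_meta, wq_scales, bias
    )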
@@ -0,0 +1,188 @@
import itertools
import random

import pytest
import torch
from torch.testing._internal.common_cuda import SM90OrLater

from torchao.dtypes import (
    Float8Layout,
    to_affine_quantized_floatx,
)
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)


X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)]
XQ_WQ_DTYPES = [
    (torch.float8_e5m2, torch.float8_e4m3fn),
    (torch.float8_e4m3fn, torch.float8_e4m3fn),
]
BATCH_SIZE = [1, 4, 8, 16, 32, 64]
SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]
USE_BIAS = [False, True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
    itertools.product(
        X_W_DTYPES,
        XQ_WQ_DTYPES,
        BATCH_SIZE,
        SIZE_MNK,
        USE_BIAS,
        BIAS_DTYPE,
    )
)


# FIXME: remove this!
X_W_DTYPES = [(torch.float16, torch.float16)]
XQ_WQ_DTYPES = [(torch.float8_e5m2, torch.float8_e4m3fn)]
BATCH_SIZE = [1]
SIZE_MNK = [(32, 64, 128)]
USE_BIAS = [True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
    itertools.product(
        X_W_DTYPES,
        XQ_WQ_DTYPES,
        BATCH_SIZE,
        SIZE_MNK,
        USE_BIAS,
        BIAS_DTYPE,
    )
)


def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
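    # Generate a dense matrix whose zeros follow the semi-structured sparsity
    # pattern the kernels expect: 2 non-zeros in every group of 4 (2:4) for
    # non-float32 dtypes, 1 in every 2 (1:2) for float32.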
pattern = "2by4" if dtype != torch.float32 else "1by2" | ||
if pattern == "1by2": | ||
ksparse = 2 | ||
choices = [[0, 1], [1, 0]] | ||
elif pattern == "2by4": | ||
ksparse = 4 | ||
choices = [ | ||
[1, 1, 0, 0], | ||
[1, 0, 1, 0], | ||
[1, 0, 0, 1], | ||
[0, 1, 1, 0], | ||
[0, 1, 0, 1], | ||
[0, 0, 1, 1], | ||
] | ||
assert c % ksparse == 0 | ||
mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)] | ||
mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device) | ||
dense = torch.randn(r, c, dtype=dtype, device=device) | ||
dense[dense == 0] = 1 # To prevent zeros except where mask applied. | ||
dense = dense.masked_fill(~mask, 0) | ||
return dense | ||


def run_test_for_op(
    op,
    x_dtype,
    w_dtype,
    xq_dtype,
    wq_dtype,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda")
    w = rand_sparse_semi_structured(size_n, size_k, dtype=w_dtype, device="cuda")
    bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None

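    # Quantize activation and weight row-wise to FP8: block_size spans the
    # entire last dimension, so each row gets a single scale.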
    block_size = [1] * (x.dim() - 1) + [x.shape[-1]]
    x_aqt = to_affine_quantized_floatx(
        input_float=x,
        target_dtype=xq_dtype,
        block_size=block_size,
        _layout=Float8Layout(mm_config=None),
    )
    xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
    assert zero_points is None

    block_size = [1] * (w.dim() - 1) + [w.shape[-1]]
    w_aqt = to_affine_quantized_floatx(
        input_float=w,
        target_dtype=wq_dtype,
        block_size=block_size,
        _layout=Float8Layout(mm_config=None),
    )
    wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
    assert zero_points is None
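    # Compress the 2:4-sparse FP8 weight into packed non-zero values plus the
    # metadata tensor consumed by the CUTLASS SM90 sparse kernel.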
    wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
    wq_sp_scales = wq_scales

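    # Reference result via the row-wise dequantization identity
    #   Y = (Xq @ Wq^T) * x_scale[m] * w_scale[n] + bias,
    # computed in FP32 against the dense (uncompressed) quantized weight.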
    xq_2d = xq.view(-1, xq.shape[-1])
    size_m_2d = xq_2d.shape[0]
    output_ref = (
        (xq_2d.float() @ wq.float().T)
        * xq_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,))

    fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    # FIXME: remove this!
    d_ref = output_ref
    d = output
    print(
        f"Sum of relative errors d vs. d_ref : {torch.sum(torch.abs(d - d_ref) / torch.abs(d_ref)).item():8.2f}"
    )
    print()
    d_ref = d_ref.flatten().to(torch.float32)
    d = d.flatten().to(torch.float32)
    topk = 10
    print(f"Top {topk} relative errors d vs. d_ref :")
    print("    d_ref           d")
    print("------------+------------")
    values, indices = torch.topk(torch.abs(d - d_ref) / torch.abs(d_ref), topk)
    for index in indices:
        print(f"{d_ref[index].item():12.5e} {d[index].item():12.5e}")
    print()

    torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-3)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices")
@pytest.mark.parametrize(
    "x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype",
    TEST_PARAMS,
)
def test_rowwise_scaled_linear_sparse_cutlass_f8f8(
    x_w_dtypes,
    xq_wq_dtypes,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    run_test_for_op(
        rowwise_scaled_linear_sparse_cutlass_f8f8,
        *x_w_dtypes,
        *xq_wq_dtypes,
        batch_size,
        size_mnk,
        use_bias,
        bias_dtype,
    )
1 change: 1 addition & 0 deletions
torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
@@ -1,3 +1,4 @@
#include <cutlass/cutlass.h>
#include <torch/library.h>

#include "rowwise_scaled_linear_cutlass.cuh"
1 change: 1 addition & 0 deletions
torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
@@ -1,3 +1,4 @@
#include <cutlass/cutlass.h>
#include <torch/library.h>

#include "rowwise_scaled_linear_cutlass.cuh"