forked from pytorch/ao
Add FP5 E2M2 support from upstream (pytorch#399)
* first update from upstream
* add some primitives to support fp5
* binding for ExMy
* add QuantLlmLinear
* fix
* update README
* update README
* remove fp6_linear from C++
* fix
* fix
* fix
* update
* add more experimental config
* update
* add from tc_fpx
* remove redundant code
* fix import
* fix test
* avoid division by 0
* add subclass. use uint8
* subclass API
* update doc
* remove unused op
* update
* rename. update
* update docs
* rename
* fix for PyTorch 2.2
* _implements -> implements
* set CUDA context
* fix __repr__
1 parent 5913236 · commit 6b34ba5
Showing 21 changed files with 958 additions and 847 deletions.
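For orientation: the ExMy naming in this commit encodes the floating-point bit layout, with 1 sign bit, ebits exponent bits, and mbits mantissa bits, so FP6 is E3M2 (1 + 3 + 2 = 6 bits) and the newly supported FP5 is E2M2 (1 + 2 + 2 = 5 bits). The sketch below is a minimal round-trip through the helpers exercised by the tests in this diff; it is illustrative only and assumes the prototype import paths shown in the new test file.

import torch
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
from torchao.prototype.quant_llm import to_scaled_tc_fpx, from_scaled_tc_fpx

ebits, mbits = 2, 2        # FP5 E2M2; (3, 2) would be FP6 E3M2
nbits = 1 + ebits + mbits  # sign + exponent + mantissa = 5 bits per value

x = torch.randn(256, 64) * 100

# Round-trip through the unpacked FPx representation so every element of x
# becomes exactly representable in FP5 E2M2.
x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)

# Pack to the scaled tensor-core FPx layout and back; the reconstruction
# should match x (this is what test_from_tc_fpx_correctness below checks).
tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)
x_recon = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
torch.testing.assert_close(x_recon, x)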
This file was deleted.
@@ -0,0 +1,106 @@
import copy

import pytest
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.prototype.quant_llm import (
    QuantLlmLinearWeight,
    quant_llm_fpx_weight_only,
    to_scaled_tc_fpx,
    from_scaled_tc_fpx,
)
from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
from torchao.quantization.quant_api import quantize


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_FPx_DTYPES = [(3, 2), (2, 2)]


class TestQuantLlmLinearWeight(TestCase):
    @parametrize("device", _DEVICES)
    def test_pack_tc_fp6_correctness(self, device):
        x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)

        expected = _pack_tc_fpx(x, 6)
        actual = _pack_tc_fp6(x)
        torch.testing.assert_close(actual, expected)

    @parametrize("ebits,mbits", _FPx_DTYPES)
    @parametrize("device", _DEVICES)
    def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device):
        x = torch.randn(256, 64, device=device)

        expected = to_scaled_tc_fpx(x, ebits, mbits)
        actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits)
        torch.testing.assert_close(actual, expected)

    @parametrize("ebits,mbits", _FPx_DTYPES)
    @parametrize("device", _DEVICES)
    def test_from_tc_fpx_correctness(self, ebits, mbits, device):
        x = torch.randn(256, 64, device=device) * 100

        # quantize and dequantize so that the values are exactly representable in FPx
        x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)

        tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)
        actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
        torch.testing.assert_close(actual, x)

    @parametrize("ebits,mbits", _FPx_DTYPES)
    @parametrize("device", _DEVICES)
    def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
        M, N = 256, 64
        nbits = 1 + ebits + mbits
        x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device)
        scale = torch.randn(M, device=device)

        expected = from_scaled_tc_fpx(x, ebits, mbits, scale)
        actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
        torch.testing.assert_close(actual, expected)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @parametrize("ebits,mbits", _FPx_DTYPES)
    @parametrize("leading_dims", [(4,), (2, 4)])
    @parametrize("bias", [False, True])
    def test_quant_llm_linear_weight(self, ebits, mbits, bias, leading_dims):
        OC, IC = 256, 64
        device = "cuda"

        fp16_weight = torch.randn(OC, IC, device=device, dtype=torch.half)
        fp16_bias = torch.randn(OC, device=device, dtype=torch.half) if bias else None

        fpx_weight = QuantLlmLinearWeight.from_float(fp16_weight, ebits, mbits)

        x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
        out = torch.nn.functional.linear(x, fpx_weight, fp16_bias)
        assert out.shape == leading_dims + (OC,)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @parametrize("ebits,mbits", _FPx_DTYPES)
    @parametrize("bias", [False, True])
    def test_quant_llm_quantize(self, ebits, mbits, bias):
        N, OC, IC = 4, 256, 64
        device = "cuda"

        linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
        fpx_linear = copy.deepcopy(linear)
        quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))

        x = torch.randn(N, IC, device=device, dtype=torch.half)
        expected = fpx_linear(x)
        actual = torch.compile(fpx_linear, fullgraph=True)(x)
        torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestQuantLlmLinearWeight)


if __name__ == "__main__":
    run_tests()
@@ -1,60 +1,69 @@
 import torch
-from torch.testing._internal.common_utils import TestCase, IS_FBCODE
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 from torch.testing._internal.optests import opcheck
 import torchao
-from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
-import unittest
-from parameterized import parameterized
+from torchao.utils import is_fbcode
+from torchao.prototype.quant_llm import from_scaled_tc_fpx
 import pytest
 
+if is_fbcode():
+    pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels")
+
 try:
     import torchao.ops
 except RuntimeError:
     pytest.skip("torchao.ops not available")
 
 
 # torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
 # test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace)
 @pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
-@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
 class TestOps(TestCase):
-    def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
-        # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t.
-        fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
-        fp16_scale = torch.rand(OC).half() + 0.5
-        fp16_activation = torch.rand(BS, IC).half() + 0.5
-        return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)
+    def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
+        # Randomly initialize each byte
+        nbits = 1 + ebits + mbits
+        fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
+        scale = torch.rand(OC).half() + 0.5
+        fp16_act = torch.rand(BS, IC).half() + 0.5
+        return fpx_weight.to(device), scale.to(device), fp16_act.to(device)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp6_llm_linear(self):
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("ebits,mbits", [(3, 2), (2, 2)])
+    def test_quant_llm_linear(self, ebits, mbits):
         BS = 2
         OC = 256
         IC = 256
         splitK = 1
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
+        fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
 
         # smoke test
-        torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
 
         # comprehensive testing
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
+        opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils)
 
-    # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
-    @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
+    @parametrize("ebits,mbits", [(3, 2), (2, 2)])
+    def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
+        # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
+        fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
 
-        results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
 
-        fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
-        results_fp16 = fp16_activation @ fp16_weight.T
+        fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
+        results_fp16 = fp16_act @ fp16_weight.T
 
-        error = (results_fp6 - results_fp16).abs()
-        relative_error = error / results_fp16.abs()
-        assert relative_error.mean() < 1e-2
+        error = (results_fpx - results_fp16).abs().mean()
+        gt = results_fp16.abs().mean()
+        relative_error = error / gt
+        assert relative_error < 1e-3
+
+
+instantiate_parametrized_tests(TestOps)
 
 
 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
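For reference, a minimal sketch of how the new subclass API added by this commit is used from user code, distilled from test_quant_llm_quantize above; it assumes a CUDA device and the prototype import paths shown in the tests, and is not a substitute for the updated README.

import copy
import torch
from torchao.prototype.quant_llm import quant_llm_fpx_weight_only
from torchao.quantization.quant_api import quantize

ebits, mbits = 2, 2  # FP5 E2M2; (3, 2) gives FP6 E3M2
linear = torch.nn.Linear(64, 256, device="cuda")

# Per the test, quantize() swaps the weight for a QuantLlmLinearWeight
# tensor subclass in place.
fpx_linear = copy.deepcopy(linear)
quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))

x = torch.randn(4, 64, device="cuda", dtype=torch.half)
out = fpx_linear(x)                                          # eager
out_compiled = torch.compile(fpx_linear, fullgraph=True)(x)  # compiled
torch.testing.assert_close(out_compiled, out)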