Add FP5 E2M2 support from upstream (#399)
* first update from upstream

* add some primitives to support fp5

* binding for ExMy

* add QuantLlmLinear

* fix

* update README

* update README

* remove fp6_linear from C++

* fix

* fix

* fix

* update

* add more experimental config

* update

* add from tc_fpx

* remove redundant code

* fix import

* fix test

* avoid division by 0

* add subclass. use uint8

* subclass API

* update doc

* remove unused op

* update

* rename. update

* update docs

* rename

* fix for PyTorch 2.2

* _implements -> implements

* set CUDA context

* fix __repr__
gau-nernst authored Jun 25, 2024
1 parent 96d49cd commit 70aef5d
Showing 21 changed files with 958 additions and 847 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -68,7 +68,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})

* [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
* [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code. Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701)
* [fp6](torchao/prototype/fp6_llm/) for 2x faster inference over fp16 with an easy to use wrapper api `convert_fp6_llm(model)`
* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())`

## Composability

@@ -104,7 +104,7 @@ python setup.py install
* [GaLore](torchao/prototype/galore/) a drop for the Adam Optimizer that allows you to finetune llama 7b on a single 4090 card with up to 70% speedups relative to eager PyTorch
* [DoRA](torchao/prototype/dora) a newer replacement for QLoRA with more promising convergence characteristics
* [Fused int4/fp16 Quant Matmul](torchao/prototype/hqq) which is particularly useful for compute bound kernels showing 4x speedups over tinygemm for larger batch sizes such as 512
* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/fp6_llm](torchao/prototype/fp6_llm)
* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/quant_llm](torchao/prototype/quant_llm)
* [vayuda](https://github.com/vayuda) with generic bitpacking kernels that were code generated using pure PyTorch [prototype/common](torchao/prototype/common)
* [andreaskopf](https://github.com/andreaskoepf) and [melvinebenezer](https://github.com/melvinebenezer) with [1 bit LLMs](torchao/prototype/dtypes) Bitnet 1.58 bitpacked into uint2 and fully code-generated with torch.compile

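A minimal usage sketch of the subclass API this commit lands, assuming a CUDA build of torchao; it uses the `quant_llm_fpx_weight_only` factory exercised by the new tests (the README's `fp6_llm_weight_only()` appears to be the FP6 convenience spelling), and the shapes are illustrative:

```python
import torch
from torchao.quantization.quant_api import quantize
from torchao.prototype.quant_llm import quant_llm_fpx_weight_only

# build a toy model and swap its Linear weights for FPx tensor subclasses
# (quantize is applied in place here, as in the new tests)
model = torch.nn.Sequential(torch.nn.Linear(1024, 4096, device="cuda"))
quantize(model, quant_llm_fpx_weight_only(3, 2))  # FP6 E3M2; (2, 2) selects FP5 E2M2

x = torch.randn(8, 1024, dtype=torch.half, device="cuda")
with torch.no_grad():
    y = model(x)  # F.linear on the subclass weight runs the fused quant_llm_linear kernel
```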
29 changes: 14 additions & 15 deletions benchmarks/benchmark_fp6_llm.py
@@ -1,25 +1,24 @@
import torch
from torch import nn
from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
from torch.utils.benchmark import Timer
import pandas as pd
import torch.nn.functional as F
from torchao.prototype.quant_llm import QuantLlmLinearWeight
from torchao.utils import benchmark_torch_function_in_microseconds
from tqdm import tqdm


def benchmark(m: int, k: int, n: int):
fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
fp6_linear = Fp6LlmLinear(fp6_weight, scales)
fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2)

fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda")
fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None]
fp16_weight = fp6_weight.dequantize(torch.half)

fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
fp6_output = fp6_linear(fp16_act)
fp16_output = fp16_linear(fp16_act)
fp6_output = F.linear(fp16_act, fp6_weight)
fp16_output = F.linear(fp16_act, fp16_weight)

fp6_measurement = Timer(stmt="fp6_linear(fp16_act)", globals=locals()).blocked_autorange()
fp16_measurement = Timer(stmt="fp16_linear(fp16_act)", globals=locals()).blocked_autorange()
fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)

# follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
# doesn't seem to be the right way to check for correctness
@@ -29,9 +28,9 @@ def benchmark(m: int, k: int, n: int):
"m": m,
"k": k,
"n": n,
"fp6_latency (ms)": fp6_measurement.median * 1000,
"fp16_latency (ms)": fp16_measurement.median * 1000,
"speedup (d/s)": fp16_measurement.median / fp6_measurement.median,
"fp6_latency (ms)": fp6_time,
"fp16_latency (ms)": fp16_time,
"speedup (d/s)": fp16_time / fp6_time,
"correct": correct,
}

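A condensed sketch of what the updated benchmark now measures, assuming a CUDA device and illustrative shapes: `F.linear` dispatches on the `QuantLlmLinearWeight` subclass, and latency comes from `benchmark_torch_function_in_microseconds`:

```python
import torch
import torch.nn.functional as F
from torchao.prototype.quant_llm import QuantLlmLinearWeight
from torchao.utils import benchmark_torch_function_in_microseconds

m, k, n = 1, 8192, 8192  # illustrative GEMV-like shape
fp16_weight = torch.randn(n, k, dtype=torch.half, device="cuda")
fp6_weight = QuantLlmLinearWeight.from_float(fp16_weight, 3, 2)  # FP6 E3M2
fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")

# median latency in microseconds for the FPx subclass vs. plain fp16 weights
fp6_us = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
fp16_us = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
print(f"fp6: {fp6_us:.1f} us  fp16: {fp16_us:.1f} us  speedup: {fp16_us / fp6_us:.2f}x")
```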
106 changes: 0 additions & 106 deletions test/prototype/test_fp6_llm.py

This file was deleted.

106 changes: 106 additions & 0 deletions test/prototype/test_quant_llm.py
@@ -0,0 +1,106 @@
import copy

import pytest
import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.prototype.quant_llm import (
QuantLlmLinearWeight,
quant_llm_fpx_weight_only,
to_scaled_tc_fpx,
from_scaled_tc_fpx,
)
from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
from torchao.quantization.quant_api import quantize


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_FPx_DTYPES = [(3, 2), (2, 2)]


class TestQuantLlmLinearWeight(TestCase):
@parametrize("device", _DEVICES)
def test_pack_tc_fp6_correctness(self, device):
x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)

expected = _pack_tc_fpx(x, 6)
actual = _pack_tc_fp6(x)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("device", _DEVICES)
def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device)

expected = to_scaled_tc_fpx(x, ebits, mbits)
actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("device", _DEVICES)
def test_from_tc_fpx_correctness(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device) * 100

# quantize and dequantize so that the values are exactly representable in FPx
x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)

tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)
actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
torch.testing.assert_close(actual, x)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("device", _DEVICES)
def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
M, N = 256, 64
nbits = 1 + ebits + mbits
x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device)
scale = torch.randn(M, device=device)

expected = from_scaled_tc_fpx(x, ebits, mbits, scale)
actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("leading_dims", [(4,), (2, 4)])
@parametrize("bias", [False, True])
def test_quant_llm_linear_weight(self, ebits, mbits, bias, leading_dims):
OC, IC = 256, 64
device = "cuda"

fp16_weight = torch.randn(OC, IC, device=device, dtype=torch.half)
fp16_bias = torch.randn(OC, device=device, dtype=torch.half) if bias else None

fpx_weight = QuantLlmLinearWeight.from_float(fp16_weight, ebits, mbits)

x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
out = torch.nn.functional.linear(x, fpx_weight, fp16_bias)
assert out.shape == leading_dims + (OC,)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("bias", [False, True])
def test_quant_llm_quantize(self, ebits, mbits, bias):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
fpx_linear = copy.deepcopy(linear)
quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fpx_linear(x)
actual = torch.compile(fpx_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestQuantLlmLinearWeight)


if __name__ == "__main__":
run_tests()
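A small round-trip sketch of the pack/unpack helpers covered by these tests; the trip is lossy, since only values representable in the chosen FPx format are preserved exactly:

```python
import torch
from torchao.prototype.quant_llm import to_scaled_tc_fpx, from_scaled_tc_fpx

ebits, mbits = 2, 2        # FP5 E2M2; (3, 2) gives FP6 E3M2
w = torch.randn(256, 64)   # (out_features, in_features); fp32 on CPU is fine here

# pack into the tensor-core FPx layout (uint8) plus a per-row scale, then unpack
packed, scale = to_scaled_tc_fpx(w, ebits, mbits)
w_dq = from_scaled_tc_fpx(packed, ebits, mbits, scale)

# quantization error of the round trip (zero only for exactly representable values)
print((w - w_dq).abs().max())
```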
75 changes: 42 additions & 33 deletions test/test_ops.py
@@ -1,60 +1,69 @@
import torch
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.optests import opcheck
import torchao
from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
import unittest
from parameterized import parameterized
from torchao.utils import is_fbcode
from torchao.prototype.quant_llm import from_scaled_tc_fpx
import pytest

if is_fbcode():
pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels")

try:
import torchao.ops
except RuntimeError:
pytest.skip("torchao.ops not available")


# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace)
@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
class TestOps(TestCase):
def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
# Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t.
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
fp16_scale = torch.rand(OC).half() + 0.5
fp16_activation = torch.rand(BS, IC).half() + 0.5
return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_llm_linear(self):
def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
# Randomly initialize each byte
nbits = 1 + ebits + mbits
fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
scale = torch.rand(OC).half() + 0.5
fp16_act = torch.rand(BS, IC).half() + 0.5
return fpx_weight.to(device), scale.to(device), fp16_act.to(device)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
def test_quant_llm_linear(self, ebits, mbits):
BS = 2
OC = 256
IC = 256
splitK = 1
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")

# smoke test
torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")

results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)

# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
@parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
results_fp16 = fp16_act @ fp16_weight.T

results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
error = (results_fpx - results_fp16).abs().mean()
gt = results_fp16.abs().mean()
relative_error = error / gt
assert relative_error < 1e-3

fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = fp16_activation @ fp16_weight.T

error = (results_fp6 - results_fp16).abs()
relative_error = error / results_fp16.abs()
assert relative_error.mean() < 1e-2
instantiate_parametrized_tests(TestOps)


if __name__ == "__main__":
unittest.main()
run_tests()
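For reference, a direct call to the renamed custom op, mirroring the smoke test above; the packed-weight layout and shapes follow the test's `_create_fpx_inputs`, and a CUDA device is assumed:

```python
import torch
import torchao

ebits, mbits = 2, 2                 # FP5 E2M2; (3, 2) selects FP6 E3M2
BS, OC, IC, splitK = 2, 256, 256, 1
nbits = 1 + ebits + mbits

# packed FPx weight: uint8 of shape (OC, IC // 8 * nbits), with a per-row fp16 scale
fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = torch.rand(OC, dtype=torch.half, device="cuda") + 0.5
fp16_act = torch.rand(BS, IC, dtype=torch.half, device="cuda") + 0.5

out = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
assert out.shape == (BS, OC)
```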