
Add FP5 E2M2 support from upstream #399

Merged 33 commits on Jun 25, 2024.

Commits:
ce8fd7d  first update from upstream (gau-nernst, Jun 19, 2024)
d8bd7b6  add some primitives to support fp5 (gau-nernst, Jun 19, 2024)
129adff  binding for ExMy (gau-nernst, Jun 19, 2024)
64c6cee  add QuantLlmLinear (gau-nernst, Jun 19, 2024)
38ad773  fix (gau-nernst, Jun 19, 2024)
057367e  update README (gau-nernst, Jun 19, 2024)
a6ed669  update README (gau-nernst, Jun 19, 2024)
0d409a8  remove fp6_linear from C++ (gau-nernst, Jun 19, 2024)
0bd9ee1  fix (gau-nernst, Jun 19, 2024)
8fbd3d4  fix (gau-nernst, Jun 19, 2024)
5906eed  fix (gau-nernst, Jun 19, 2024)
3b008d5  update (gau-nernst, Jun 19, 2024)
9076e58  add more experimental config (gau-nernst, Jun 19, 2024)
442e9c5  update (gau-nernst, Jun 19, 2024)
d2d8019  add from tc_fpx (gau-nernst, Jun 19, 2024)
bb52ad0  remove redundant code (gau-nernst, Jun 19, 2024)
80661ab  fix import (gau-nernst, Jun 19, 2024)
edfbe3d  fix test (gau-nernst, Jun 19, 2024)
0ecdd86  avoid division by 0 (gau-nernst, Jun 19, 2024)
4e44a8d  Merge branch 'pytorch:main' into fp5_llm (gau-nernst, Jun 24, 2024)
e6c7d6b  add subclass. use uint8 (gau-nernst, Jun 24, 2024)
ca43bf8  subclass API (gau-nernst, Jun 24, 2024)
8de2722  update doc (gau-nernst, Jun 24, 2024)
50bfe82  remove unused op (gau-nernst, Jun 24, 2024)
b9375a4  update (gau-nernst, Jun 24, 2024)
3072257  rename. update (gau-nernst, Jun 24, 2024)
7b822ef  update docs (gau-nernst, Jun 24, 2024)
ca45dda  rename (gau-nernst, Jun 24, 2024)
36fe61e  fix for PyTorch 2.2 (gau-nernst, Jun 24, 2024)
30608f3  Merge branch 'main' into fp5_llm (gau-nernst, Jun 24, 2024)
57ad040  _implements -> implements (gau-nernst, Jun 24, 2024)
ceaa71c  set CUDA context (gau-nernst, Jun 25, 2024)
4e585e9  fix __repr__ (gau-nernst, Jun 25, 2024)
4 changes: 2 additions & 2 deletions README.md
@@ -68,7 +68,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})

* [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
* [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code. Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701)
* [fp6](torchao/prototype/fp6_llm/) for 2x faster inference over fp16 with an easy to use wrapper api `convert_fp6_llm(model)`
* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())`

## Composability

@@ -104,7 +104,7 @@ python setup.py install
* [GaLore](torchao/prototype/galore/) a drop for the Adam Optimizer that allows you to finetune llama 7b on a single 4090 card with up to 70% speedups relative to eager PyTorch
* [DoRA](torchao/prototype/dora) a newer replacement for QLoRA with more promising convergence characteristics
* [Fused int4/fp16 Quant Matmul](torchao/prototype/hqq) which is particularly useful for compute bound kernels showing 4x speedups over tinygemm for larger batch sizes such as 512
* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/fp6_llm](torchao/prototype/fp6_llm)
* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/quant_llm](torchao/prototype/quant_llm)
* [vayuda](https://github.com/vayuda) with generic bitpacking kernels that were code generated using pure PyTorch [prototype/common](torchao/prototype/common)
* [andreaskopf](https://github.com/andreaskoepf) and [melvinebenezer](https://github.com/melvinebenezer) with [1 bit LLMs](torchao/prototype/dtypes) Bitnet 1.58 bitpacked into uint2 and fully code-generated with torch.compile

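The README change above replaces the old `convert_fp6_llm(model)` wrapper with the `quantize` API. Below is a minimal usage sketch of that flow, mirroring the `quant_llm_fpx_weight_only` calls in the tests added later in this PR (the README's `fp6_llm_weight_only()` is presumably the FP6 E3M2 convenience variant); the layer sizes are illustrative, and a CUDA device with the compiled torchao kernels is assumed.

```python
import torch
from torchao.prototype.quant_llm import quant_llm_fpx_weight_only
from torchao.quantization.quant_api import quantize

# ebits=3, mbits=2 is FP6 E3M2; the FP5 E2M2 format added by this PR is ebits=2, mbits=2.
model = torch.nn.Linear(64, 256, bias=True, device="cuda")
quantize(model, quant_llm_fpx_weight_only(2, 2))  # weight becomes a QuantLlmLinearWeight

x = torch.randn(4, 64, device="cuda", dtype=torch.half)
out = model(x)  # F.linear now dispatches to the fused quant_llm CUDA kernel
```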
29 changes: 14 additions & 15 deletions benchmarks/benchmark_fp6_llm.py
@@ -1,25 +1,24 @@
import torch
from torch import nn
from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
from torch.utils.benchmark import Timer
import pandas as pd
import torch.nn.functional as F
from torchao.prototype.quant_llm import QuantLlmLinearWeight
from torchao.utils import benchmark_torch_function_in_microseconds
from tqdm import tqdm


def benchmark(m: int, k: int, n: int):
fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
fp6_linear = Fp6LlmLinear(fp6_weight, scales)
fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2)

fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda")
fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None]
fp16_weight = fp6_weight.dequantize(torch.half)

fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
fp6_output = fp6_linear(fp16_act)
fp16_output = fp16_linear(fp16_act)
fp6_output = F.linear(fp16_act, fp6_weight)
fp16_output = F.linear(fp16_act, fp16_weight)

fp6_measurement = Timer(stmt="fp6_linear(fp16_act)", globals=locals()).blocked_autorange()
fp16_measurement = Timer(stmt="fp16_linear(fp16_act)", globals=locals()).blocked_autorange()
fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)

# follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
# doesn't seem to be the right way to check for correctness
@@ -29,9 +28,9 @@ def benchmark(m: int, k: int, n: int):
"m": m,
"k": k,
"n": n,
"fp6_latency (ms)": fp6_measurement.median * 1000,
"fp16_latency (ms)": fp16_measurement.median * 1000,
"speedup (d/s)": fp16_measurement.median / fp6_measurement.median,
"fp6_latency (ms)": fp6_time,
"fp16_latency (ms)": fp16_time,
"speedup (d/s)": fp16_time / fp6_time,
"correct": correct,
}

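The `(n, k * 3 // 4)` shape of the packed FP6 data in the benchmark above follows from bit arithmetic: an FPx element occupies nbits = 1 + ebits + mbits bits, so a row of k elements packs into k * nbits // 8 bytes. A small sketch of that arithmetic (the helper name is hypothetical, not part of the PR):

```python
def packed_bytes_per_row(k: int, ebits: int, mbits: int) -> int:
    """Bytes needed to pack one row of k FPx values (hypothetical helper)."""
    nbits = 1 + ebits + mbits        # sign + exponent + mantissa bits
    assert (k * nbits) % 8 == 0      # assumes the row packs into whole bytes
    return k * nbits // 8

assert packed_bytes_per_row(256, 3, 2) == 256 * 3 // 4  # FP6 E3M2: 192 bytes
assert packed_bytes_per_row(256, 2, 2) == 256 * 5 // 8  # FP5 E2M2: 160 bytes
```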
106 changes: 0 additions & 106 deletions test/prototype/test_fp6_llm.py

This file was deleted.

106 changes: 106 additions & 0 deletions test/prototype/test_quant_llm.py
@@ -0,0 +1,106 @@
import copy

import pytest
import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.prototype.quant_llm import (
QuantLlmLinearWeight,
quant_llm_fpx_weight_only,
to_scaled_tc_fpx,
from_scaled_tc_fpx,
)
from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
from torchao.quantization.quant_api import quantize


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_FPx_DTYPES = [(3, 2), (2, 2)]


class TestQuantLlmLinearWeight(TestCase):
@parametrize("device", _DEVICES)
def test_pack_tc_fp6_correctness(self, device):
x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)

expected = _pack_tc_fpx(x, 6)
actual = _pack_tc_fp6(x)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("device", _DEVICES)
def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device)

expected = to_scaled_tc_fpx(x, ebits, mbits)
actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("device", _DEVICES)
def test_from_tc_fpx_correctness(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device) * 100

# quantize and dequantize so that the values are exactly representable in FPx
x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)

tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)
actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
torch.testing.assert_close(actual, x)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("device", _DEVICES)
def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
M, N = 256, 64
nbits = 1 + ebits + mbits
x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device)
scale = torch.randn(M, device=device)

expected = from_scaled_tc_fpx(x, ebits, mbits, scale)
actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("leading_dims", [(4,), (2, 4)])
@parametrize("bias", [False, True])
def test_quant_llm_linear_weight(self, ebits, mbits, bias, leading_dims):
OC, IC = 256, 64
device = "cuda"

fp16_weight = torch.randn(OC, IC, device=device, dtype=torch.half)
fp16_bias = torch.randn(OC, device=device, dtype=torch.half) if bias else None

fpx_weight = QuantLlmLinearWeight.from_float(fp16_weight, ebits, mbits)

x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
out = torch.nn.functional.linear(x, fpx_weight, fp16_bias)
assert out.shape == leading_dims + (OC,)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("bias", [False, True])
def test_quant_llm_quantize(self, ebits, mbits, bias):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
fpx_linear = copy.deepcopy(linear)
quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fpx_linear(x)
actual = torch.compile(fpx_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestQuantLlmLinearWeight)


if __name__ == "__main__":
run_tests()
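Condensed from `test_from_tc_fpx_correctness` above, here is a short round-trip sketch of the pack/unpack primitives. The round trip is exact only because the input is first snapped onto the FPx grid, and it runs on CPU as well, per the test's device parametrization.

```python
import torch
from torchao.prototype.quant_llm import to_scaled_tc_fpx, from_scaled_tc_fpx
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32

ebits, mbits = 2, 2                   # FP5 E2M2 (use 3, 2 for FP6 E3M2)
x = torch.randn(256, 64) * 100
# snap onto the FPx grid so the scaled round trip is exact
x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)

tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)       # packed uint8 data + per-row scale
x_recon = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
torch.testing.assert_close(x_recon, x)
```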
75 changes: 42 additions & 33 deletions test/test_ops.py
@@ -1,60 +1,69 @@
import torch
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.optests import opcheck
import torchao
from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
import unittest
from parameterized import parameterized
from torchao.utils import is_fbcode
from torchao.prototype.quant_llm import from_scaled_tc_fpx
import pytest

if is_fbcode():
pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels")

try:
import torchao.ops
except RuntimeError:
pytest.skip("torchao.ops not available")


# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace)
@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
class TestOps(TestCase):
def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
# Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t.
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
fp16_scale = torch.rand(OC).half() + 0.5
fp16_activation = torch.rand(BS, IC).half() + 0.5
return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_llm_linear(self):
def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
# Randomly initialize each byte
nbits = 1 + ebits + mbits
fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
scale = torch.rand(OC).half() + 0.5
fp16_act = torch.rand(BS, IC).half() + 0.5
return fpx_weight.to(device), scale.to(device), fp16_act.to(device)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
def test_quant_llm_linear(self, ebits, mbits):
BS = 2
OC = 256
IC = 256
splitK = 1
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")

# smoke test
torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")

results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)

# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
@parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
results_fp16 = fp16_act @ fp16_weight.T

results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
error = (results_fpx - results_fp16).abs().mean()
gt = results_fp16.abs().mean()
relative_error = error / gt
assert relative_error < 1e-3

fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = fp16_activation @ fp16_weight.T

error = (results_fp6 - results_fp16).abs()
relative_error = error / results_fp16.abs()
assert relative_error.mean() < 1e-2
instantiate_parametrized_tests(TestOps)


if __name__ == "__main__":
unittest.main()
run_tests()
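Finally, a sketch of calling the fused kernel directly, condensed from `test_quant_llm_linear_correctness` above; it assumes a CUDA device and the compiled torchao extension, with shapes and splitK taken from the test's parametrization.

```python
import torch
import torchao.ops
from torchao.prototype.quant_llm import from_scaled_tc_fpx

ebits, mbits = 2, 2                                # FP5 E2M2
BS, OC, IC, splitK = 2, 8192, 8192, 6
nbits = 1 + ebits + mbits

fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = torch.rand(OC, device="cuda").half() + 0.5
fp16_act = torch.rand(BS, IC, device="cuda").half() + 0.5

out_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)

# reference: dequantize the packed weight to fp16 and use a plain matmul
fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
out_fp16 = fp16_act @ fp16_weight.T

relative_error = (out_fpx - out_fp16).abs().mean() / out_fp16.abs().mean()
assert relative_error < 1e-3
```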