
Add FP5 E2M2 support from upstream #399

Merged: 33 commits from gau-nernst:fp5_llm into pytorch:main on Jun 25, 2024

Conversation

@gau-nernst (Collaborator) commented on Jun 19, 2024

Upstream commit: usyd-fsalab/fp6_llm@5df6737

Also closes #402

New API

from torchao.quantization.quant_api import quantize
from torchao.prototype.quant_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only

model = ...
model.half()  # not necessary, but recommended to maintain accuracy
quantize(model, fp6_llm_weight_only())  # convert nn.Linear.weight to FP6 E3M2 in-place

# for generic FPx EyMz where x = 1 + y + z
# quantize(model, quant_llm_fpx_weight_only(2, 2))  # use FP5 E2M2 instead

# fully compatible with torch.compile()
model.compile(mode="max-autotune", fullgraph=True)
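
A minimal end-to-end sketch of the API above on a toy model (my own illustration, not part of the PR: it assumes a CUDA device and FP16 inputs, with layer sizes chosen as multiples of 64, and a small nn.Sequential standing in for a real LLM):

import torch
from torch import nn
from torchao.quantization.quant_api import quantize
from torchao.prototype.quant_llm import quant_llm_fpx_weight_only

# hypothetical toy model standing in for a real LLM; assumes CUDA + fp16
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model = model.half().cuda()

# FP5 E2M2: 2 exponent bits, 2 mantissa bits (per the comment above)
quantize(model, quant_llm_fpx_weight_only(2, 2))

x = torch.randn(1, 1024, dtype=torch.float16, device="cuda")
with torch.no_grad():
    out = model(x)
print(out.shape)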

Benchmark results

Benchmarks are run on a machine with a single 4070Ti SUPER GPU, using the scripts in _models/llama. tokens/s is measured with generate.py, which generates text in a latency-optimized way (batch size 1). Wikitext perplexity is measured with eval.py, which uses lm_eval. The model used is meta-llama/Llama-2-7b-chat-hf.

FPx quantization is run with --precision float16; the rest use the default precision of bfloat16.

Quantization         | wikitext perplexity | tokens/s
INT8                 | 12.21               | 87.45
INT4-256 (tinygemm)  | (bug)               | 157.10
FP6 E3M2             | 12.34               | 106.76
FP6 E2M3             | 12.23               | 106.77
FP5 E3M1             | 12.55               | 122.69
FP5 E2M2             | 12.47               | 122.66
FP4 E3M0             | 14.58               | 145.55
FP4 E2M1             | 15.01               | 146.05
FP3 E2M0             | 74625.18            | 164.49

pytorch-bot commented on Jun 19, 2024

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/399

✅ No failures as of commit 4e585e9 with merge base c2cf973.

facebook-github-bot added the CLA Signed label on Jun 19, 2024
gau-nernst marked this pull request as ready for review on Jun 24, 2024
gau-nernst requested a review from msaroufim on Jun 24, 2024
msaroufim requested a review from jerryzh168 on Jun 24, 2024
@msaroufim (Member) commented on Jun 24, 2024

For the nightly failure in CI, make sure to rebase onto main in ao. There's an upstream issue with Triton (#429) that we haven't debugged yet.

EDIT: The fix was merged.



class QuantLlmLinearWeight(Tensor):
    _implements = classmethod(_implements)

@jerryzh168 (Contributor) commented on Jun 24, 2024

haha, this is clever

nit: I think we can just do implements = classmethod(_implements); implements can be a public classmethod
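
For context, a rough sketch of the pattern being discussed, using hypothetical names (the real _implements helper in torchao may have a different signature): a classmethod decorator registers per-op handlers that __torch_dispatch__ later looks up.

import torch
from torch import Tensor

class MySubclass(Tensor):
    # hypothetical sketch; maps aten ops to handler functions
    _op_table = {}

    @classmethod
    def implements(cls, aten_ops):
        # public classmethod decorator, as suggested in the nit above
        def decorator(fn):
            for op in aten_ops:
                cls._op_table[op] = fn
            return fn
        return decorator

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in cls._op_table:
            return cls._op_table[func](func, args, kwargs)
        raise NotImplementedError(f"{cls.__name__} does not support {func}")

@MySubclass.implements([torch.ops.aten.detach.default])
def _(func, args, kwargs):
    # a real subclass would unwrap, apply the op, and rewrap here
    return args[0]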

msaroufim mentioned this pull request on Jun 24, 2024
@jerryzh168 (Contributor) commented on Jun 25, 2024

The tensor subclass changes LGTM; I'll leave the rest to @msaroufim. I'll also take a closer look at the rest of the code tomorrow.

@jerryzh168 (Contributor) commented

Another nit comment I have is the quant_llm name: should we rename it to something more specific, like fp6/fp5? Also, what is the difference between this and MX?

@gau-nernst (Collaborator, Author) commented

> Another nit comment I have is the quant_llm name: should we rename it to something more specific, like fp6/fp5?

I also don't really like the Quant-LLM name since it's quite ambiguous, but the upstream repo renamed itself to that (https://github.com/usyd-fsalab/fp6_llm), so I followed it. The original name (first release) was FP6-LLM.

The kernel supports arbitrary FP2 through FP7, so naming it only after FP6/FP5 would be too limiting. My benchmarks show that FP4 is not competitive with INT4 tinygemm (it might be improved by block-wise quantization). For now I only enable FP6 E3M2, FP6 E2M3, FP5 E2M2, and FP5 E3M1 (each dtype is a separate template instantiation; the upstream repo only tested FP6 E3M2 and FP5 E2M2). Something like quant_llm_fpx might be a better name. In the future the upstream repo might even support INTx.

I'm open to changing the name.
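
A quick illustration of the EyMz accounting mentioned above (x = 1 sign bit + y exponent bits + z mantissa bits), listing the instantiations enabled in this PR:

# illustration only: FP6 E3M2, FP6 E2M3, FP5 E3M1, FP5 E2M2
for ebits, mbits in [(3, 2), (2, 3), (3, 1), (2, 2)]:
    nbits = 1 + ebits + mbits  # sign + exponent + mantissa
    print(f"FP{nbits} E{ebits}M{mbits}")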

> Also, what is the difference between this and MX?

The FPx dtype itself is quite similar to MX, so I refactored the dtype-casting code from @vkuzo (#363) and re-used it here. However, there are key differences:

  • The MX spec only specifies FP6 E3M2, FP6 E2M3, and FP4 E2M1, while the Quant-LLM kernel supports arbitrary FPx as mentioned above.
  • MX dtypes use a block-wise quantization scale stored in E8M0 format, whereas the Quant-LLM kernel only supports a per-row quantization scale (see the sketch after this comment).
  • Quant-LLM uses a special layout optimized for tensor cores (see pack_tc_fpx()). The idea is similar to tinygemm, I think.

I think those differences are significant enough to warrant separate subclasses. It also helps with maintenance.
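
To make the second difference concrete, a minimal sketch (not the torchao implementation) contrasting a per-row absmax scale, as used by the Quant-LLM kernel, with an MX-style block-wise scale; the block size of 32 is illustrative only:

import torch

def per_row_absmax_scale(w: torch.Tensor) -> torch.Tensor:
    # one scale per output row: shape (out_features,)
    return w.abs().amax(dim=1)

def block_wise_absmax_scale(w: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # one scale per block of `block_size` consecutive weights along dim 1:
    # shape (out_features, in_features // block_size); illustrative only
    out_features, in_features = w.shape
    assert in_features % block_size == 0
    blocks = w.reshape(out_features, in_features // block_size, block_size)
    return blocks.abs().amax(dim=-1)

w = torch.randn(8, 64)
print(per_row_absmax_scale(w).shape)    # torch.Size([8])
print(block_wise_absmax_scale(w).shape) # torch.Size([8, 2])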

@jerryzh168 (Contributor) commented

Thanks for the detailed context @gau-nernst. OK, it makes sense to keep the name for now, but maybe we can talk to the author about making it a bit more descriptive, like you said (e.g. quant_llm_fpx). In torchao I feel we should eventually use fpx (as a name for the dtype) to indicate that it supports different floating-point dtypes.

msaroufim merged commit 70aef5d into pytorch:main on Jun 25, 2024 (13 checks passed)
gau-nernst deleted the fp5_llm branch on Jun 25, 2024
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* first update from upstream

* add some primitives to support fp5

* binding for ExMy

* add QuantLlmLinear

* fix

* update README

* update README

* remove fp6_linear from C++

* fix

* fix

* fix

* update

* add more experimental config

* update

* add from tc_fpx

* remove redundant code

* fix import

* fix test

* avoid division by 0

* add subclass. use uint8

* subclass API

* update doc

* remove unused op

* update

* rename. update

* update docs

* rename

* fix for PyTorch 2.2

* _implements -> implements

* set CUDA context

* fix __repr__
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Successfully merging this pull request may close these issues:

  • Migrate FP6-LLM implementation from module to tensor subclass