Add FP5 E2M2 support from upstream #399
Conversation
✅ No failures as of commit 4e585e9 with merge base c2cf973. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
For the nightly failure in CI, make sure to rebase onto main in ao. There's an upstream issue with Triton we haven't debugged yet (#429). EDIT: the fix was merged.
class QuantLlmLinearWeight(Tensor):
    _implements = classmethod(_implements)
haha this is clever. nit: we can just do `implements = classmethod(_implements)` I think; `implements` can be a public classmethod.
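For readers following along, here is a minimal sketch of the dispatch-registration pattern being discussed, assuming a module-level `_implements` helper similar to the one used elsewhere in torchao; the class and handler below are illustrative, not the PR's exact code:

```python
import torch
from torch import Tensor

def _implements(cls, aten_ops):
    """Decorator factory: register an implementation for the given aten ops
    in the subclass's own dispatch table."""
    def decorator(fn):
        for op in aten_ops:
            cls._ATEN_OP_TABLE[op] = fn
        return fn
    return decorator

class MyQuantWeight(Tensor):
    _ATEN_OP_TABLE = {}

    # expose the helper as a public classmethod, as suggested in the review
    implements = classmethod(_implements)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in cls._ATEN_OP_TABLE:
            return cls._ATEN_OP_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"{cls.__name__} does not support {func}")

# register an op handler through the public classmethod
@MyQuantWeight.implements([torch.ops.aten.detach.default])
def _(func, args, kwargs):
    return args[0]
```

Exposing the helper as a public classmethod lets each subclass register its own op implementations without a shared global dispatch table.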
tensor subclass changes LGTM, I'll leave the rest to @msaroufim. I'll also take a closer look at the rest of the code tomorrow.
Another nit comment I have is the Quant-LLM naming.
I also don't really like the Quant-LLM name since it's quite ambiguous, but the upstream repo renamed it to that (https://github.com/usyd-fsalab/fp6_llm), so I follow it. The original name (first release) was FP6-LLM. The kernel supports arbitrary FP2 -> FP7, so naming it only FP6/FP5 would be too limited. My benchmarks show that FP4 is not competitive with INT4 tinygemm (it might be improved by block-wise quantization). For now I only enable FP6 E3M2, FP6 E2M3, FP5 E2M2, and FP5 E3M1 (each dtype is a separate template instantiation; the upstream repo only tested FP6 E3M2 and FP5 E2M2). Something better might be something like `quant_llm_fpx`. I'm open to changing the name.
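To make the ExMy naming concrete, here is a small self-contained sketch (not the CUDA kernel) that decodes a 1-sign / E-exponent / M-mantissa bit pattern into a float, assuming all exponent codes encode finite values (no inf/nan), which is what makes FP2 through FP7 a single parameterized family:

```python
def decode_fpx(bits: int, ebits: int, mbits: int) -> float:
    """Decode a (1 + ebits + mbits)-bit floating-point pattern into a float.
    Assumes no inf/nan codes are reserved. E.g. FP6 E3M2 -> ebits=3, mbits=2."""
    sign = -1.0 if (bits >> (ebits + mbits)) & 1 else 1.0
    exp = (bits >> mbits) & ((1 << ebits) - 1)
    man = bits & ((1 << mbits) - 1)
    bias = (1 << (ebits - 1)) - 1
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (man / (1 << mbits)) * 2.0 ** (1 - bias)
    return sign * (1.0 + man / (1 << mbits)) * 2.0 ** (exp - bias)

# FP5 E2M2: 1 sign + 2 exponent + 2 mantissa bits
print(decode_fpx(0b0_11_11, ebits=2, mbits=2))   # largest positive value: 7.0
# FP6 E3M2: 1 sign + 3 exponent + 2 mantissa bits
print(decode_fpx(0b0_111_11, ebits=3, mbits=2))  # largest positive value: 28.0
```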
The FPx dtype itself is quite similar to MX, so I refactored the dtype casting code from @vkuzo (#363) and re-use it here. However, there are key differences.
I think those differences are significant enough to warrant separate subclasses. It also helps with maintenance.
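As a rough illustration of what the dtype casting amounts to, here is a slow reference sketch of per-tensor scaled rounding onto an ExMy grid. It is not the PR's implementation (which packs bits into uint8 and fuses dequantization into the GEMM kernel), and the per-tensor scaling choice here is an assumption made for simplicity:

```python
import torch

def fpx_representable_values(ebits: int, mbits: int) -> torch.Tensor:
    """Enumerate all finite values of a (1 + ebits + mbits)-bit format
    (no inf/nan codes), for reference-quality rounding."""
    bias = 2 ** (ebits - 1) - 1
    vals = []
    for exp in range(2 ** ebits):
        for man in range(2 ** mbits):
            if exp == 0:  # subnormals
                mag = (man / 2 ** mbits) * 2.0 ** (1 - bias)
            else:
                mag = (1 + man / 2 ** mbits) * 2.0 ** (exp - bias)
            vals += [mag, -mag]
    return torch.tensor(sorted(set(vals)))

def fake_quantize_fpx(x: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor:
    """Per-tensor scaled round-to-nearest onto the ExMy grid (reference only)."""
    grid = fpx_representable_values(ebits, mbits).to(x.dtype)
    scale = x.abs().amax() / grid.max()
    idx = torch.argmin((x.flatten()[:, None] / scale - grid[None, :]).abs(), dim=1)
    return (grid[idx] * scale).reshape(x.shape)

w = torch.randn(128, 128, dtype=torch.float16)
w_fp5 = fake_quantize_fpx(w, ebits=2, mbits=2)   # FP5 E2M2
print((w - w_fp5).abs().mean())                  # mean quantization error
```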
Thanks for the detailed context @gau-nernst. OK, it makes sense to keep the name for now, but maybe we can talk to the author about making the name a bit more descriptive like you said (e.g. `quant_llm_fpx`). In torchao I feel we should probably use
* first update from upstream
* add some primitives to support fp5
* binding for ExMy
* add QuantLlmLinear
* fix
* update README
* update README
* remove fp6_linear from C++
* fix
* fix
* fix
* update
* add more experimental config
* update
* add from tc_fpx
* remove redundant code
* fix import
* fix test
* avoid division by 0
* add subclass. use uint8
* subclass API
* update doc
* remove unused op
* update
* rename. update
* update docs
* rename
* fix for PyTorch 2.2
* _implements -> implements
* set CUDA context
* fix __repr__
usyd-fsalab/fp6_llm@5df6737
Also closes #402
New API
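The code snippet that originally accompanied this section is not preserved in this copy. As an illustration only, applying the new subclass to a model presumably looks roughly like the sketch below; the import path and the `from_float` constructor name are assumptions, not the PR's verbatim API:

```python
import torch
# import path is an assumption about where the prototype lives after this PR
from torchao.prototype.quant_llm import QuantLlmLinearWeight

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

# wrap each Linear's weight with the FPx subclass; FP5 E2M2 -> ebits=2, mbits=2.
# `from_float` is a hypothetical constructor name used for illustration.
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        fpx_weight = QuantLlmLinearWeight.from_float(module.weight, ebits=2, mbits=2)
        module.weight = torch.nn.Parameter(fpx_weight, requires_grad=False)
```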
Benchmark results
Benchmarks are run on a machine with a single 4070Ti SUPER GPU using the scripts in `_models/llama`. tokens/s is measured using `generate.py`, which generates text in a latency-optimized way (batch size = 1). wikitext perplexity is measured using `eval.py`, which uses lm_eval. The model used is meta-llama/Llama-2-7b-chat-hf. FPx quantization is run with `--precision float16`; the rest uses the default precision of `bfloat16`.