
[float8] Add support for blockwise fp8 quantization scheme used in DeepSeek v3 #1594

Open
danielvegamyhre opened this issue Jan 22, 2025 · 3 comments
Labels: float8, inference, topic: new feature

danielvegamyhre (Contributor) commented Jan 22, 2025
DeepSeek v3 uses a block-wise fp8 quantization strategy, where the scaling factor is computed independently for each block rather than per tensor or per row. The reference code is available in the DeepSeek-V3 repository.

It would be useful for torchao to support this as well, for users wishing to do research or development with this same quantization strategy.

cc @drisspg @vkuzo
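
For concreteness, here is a minimal eager-mode sketch of the block-wise scheme, assuming 128x128 weight blocks and the float8_e4m3fn dtype; the block size, dtype choice, and function name are illustrative, not an existing torchao API:

```python
import torch

def quantize_fp8_blockwise(w: torch.Tensor, block_size: int = 128):
    """Quantize a 2D tensor to float8_e4m3fn with one scale per (block_size x block_size) block."""
    M, N = w.shape
    assert M % block_size == 0 and N % block_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    # Reshape so each (block_size, block_size) block can be reduced independently.
    blocks = w.reshape(M // block_size, block_size, N // block_size, block_size)
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max  # one scale per block
    w_fp8 = (blocks / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_fp8.reshape(M, N), scale.reshape(M // block_size, N // block_size)

w = torch.randn(256, 512)
w_fp8, scales = quantize_fp8_blockwise(w)
# Dequantized reconstruction error, as a sanity check.
w_hat = w_fp8.float().reshape(2, 128, 4, 128) * scales.reshape(2, 1, 4, 1)
print((w - w_hat.reshape(256, 512)).abs().max())
```

Per the DeepSeek v3 report, weights use 128x128 blocks while activations use 1x128 groups along the channel dim; an actual torchao integration would presumably feed these per-block scales into an fp8 GEMM rather than dequantizing in eager mode as above.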

gau-nernst (Collaborator) commented:

Just want to add some of my observations here. I played around a bit with block-wise FP8 on my consumer GPU (4070 Ti SUPER, sm89). A simple Triton kernel does not perform very well: only a 1.4x speedup over BF16 (for reference, row-wise FP8 gives ~1.9x). With the dynamic quantization overhead, the end-to-end speedup won't be very attractive. (Of course, optimizing for Hopper is a completely different story.)

I also tried block-wise INT8 (which is the main idea of JetFire). A simple Triton kernel performs somewhat OK on a consumer GPU (1.9x speedup over BF16, versus 2.9x for row-wise INT8; note that INT8 matmul is 4x faster than BF16 on consumer GPUs), but on A100 I couldn't get any speedup (speedup < 1). This is probably because block-wise INT8 requires a dtype conversion from INT32 to FP32 when scaling the MMA accumulator results, while FP8 does not.

Regarding the quantization BLOCK_SIZE_K (the number of elements along the K dim that share one scale value), I think only K <= 128 admits a simple and performant implementation: if BLOCK_SIZE_K is too large, the kernel uses too much shared memory. I tried a few workarounds, such as loading tiles smaller than the quantization BLOCK_SIZE_K, but couldn't make them fast.
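
To make the BLOCK_SIZE_K point concrete, here is a rough eager-mode reference (not a kernel) of how per-K-block scales fold into the accumulation; the shapes, names, and the scale layouts are assumptions for illustration:

```python
import torch

def blockwise_scaled_mm_ref(a_fp8, a_scale, b_fp8, b_scale, block_k=128):
    # a_fp8: (M, K) fp8 activations, a_scale: (M, K // block_k)
    # b_fp8: (K, N) fp8 weights,     b_scale: (K // block_k, N), weight block scales pre-broadcast along N
    M, K = a_fp8.shape
    _, N = b_fp8.shape
    out = torch.zeros(M, N, dtype=torch.float32)
    for kb in range(K // block_k):
        ks = slice(kb * block_k, (kb + 1) * block_k)
        # Within one K block both scales are constant, so they factor out of the partial dot product;
        # the accumulator therefore has to be rescaled once per BLOCK_SIZE_K chunk of K.
        partial = a_fp8[:, ks].float() @ b_fp8[ks, :].float()
        out += partial * a_scale[:, kb:kb + 1] * b_scale[kb:kb + 1, :]
    return out
```

Inside an actual Triton or CUTLASS kernel, the same rescale has to happen once per K block of the inner loop, which is why the K tile size is tied to the quantization BLOCK_SIZE_K and hence to shared memory usage, as noted above.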

drisspg (Contributor) commented Jan 22, 2025

We should also take a look at the new block-wise fp8 GEMM added in CUTLASS 3.7.

cc @alexsamardzic

Degnel commented Feb 5, 2025

A PR has been created for this issue: #1668
