
[MX] Support mixed MXFP4/FP6/FP8 linear layer #1666

Open
balancap opened this issue Feb 5, 2025 · 1 comment

balancap (Contributor) commented Feb 5, 2025

Blackwell hardware natively supports any combination of MXFP4/FP6/FP8 in matmuls. See the PTX and Cutlass documentation.

According to the MX paper, and more generally the quantization literature, there are advantages to using different bitwidths for weights, activations, and gradients. It would be very useful for mx_mm and MXLinear to support this more general setting.
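
For illustration, here is a minimal sketch (not torchao's actual mx_mm code) of where each element dtype would apply in a mixed-precision MX linear: activations and weights are quantized in the forward pass, incoming gradients in the backward pass. The fake_quantize_mx helper and the string dtype names are placeholders, not real APIs.

    import torch

    def fake_quantize_mx(x, elem_dtype, block_size):
        # Placeholder: real MX quantization would compute a shared power-of-two scale
        # per block of `block_size` elements and cast the elements to `elem_dtype`.
        return x

    class MixedMXLinearFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, w, in_dtype, w_dtype, grad_dtype, block_size):
            # For simplicity this sketch saves the high-precision tensors.
            ctx.save_for_backward(x, w)
            ctx.grad_dtype, ctx.block_size = grad_dtype, block_size
            x_mx = fake_quantize_mx(x, in_dtype, block_size)  # e.g. MXFP8 activations
            w_mx = fake_quantize_mx(w, w_dtype, block_size)   # e.g. MXFP4 weights
            return x_mx @ w_mx.t()

        @staticmethod
        def backward(ctx, grad_out):
            x, w = ctx.saved_tensors
            # e.g. MXFP6 gradients
            g_mx = fake_quantize_mx(grad_out, ctx.grad_dtype, ctx.block_size)
            return g_mx @ w, g_mx.t() @ x, None, None, None, None

    y = MixedMXLinearFn.apply(
        torch.randn(8, 64, requires_grad=True), torch.randn(128, 64),
        "fp8_e4m3", "fp4_e2m1", "fp6_e3m2", 32)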

balancap (Contributor, Author) commented Feb 5, 2025

#1667 adds mixed element dtype support to mx_mm. A similar, general interface could be added to MXLinear:

class MXLinear(torch.nn.Linear):
    @classmethod
    @torch.no_grad()
    def from_float(cls, mod, in_elem_dtype, w_elem_dtype, grad_elem_dtype, block_size):
        # Convert a float `mod`, with separate MX element dtypes for
        # activations (inputs), weights and gradients.
        ...
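
A hedged usage sketch of the proposed conversion follows; the import paths, the dtype constants (DTYPE_FP4, DTYPE_FP6_E3M2) and the exact keyword names are assumptions rather than the final API.

    import torch
    from torchao.prototype.mx_formats.mx_linear import MXLinear            # assumed path
    from torchao.prototype.mx_formats.constants import DTYPE_FP4, DTYPE_FP6_E3M2

    lin = torch.nn.Linear(1024, 4096, bias=False, dtype=torch.bfloat16)
    mx_lin = MXLinear.from_float(
        lin,
        in_elem_dtype=torch.float8_e4m3fn,  # MXFP8 activations
        w_elem_dtype=DTYPE_FP4,             # MXFP4 weights
        grad_elem_dtype=DTYPE_FP6_E3M2,     # MXFP6 gradients
        block_size=32,
    )
    out = mx_lin(torch.randn(16, 1024, dtype=torch.bfloat16))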
