Float8Linear autocast's torch.get_autocast_gpu_dtype() is deprecated #1522

Open
apaz-cli opened this issue Jan 7, 2025 · 1 comment

@apaz-cli

apaz-cli commented Jan 7, 2025

When I use:

```python
from torchao.float8.float8_linear import Float8Linear, Float8LinearConfig

config = Float8LinearConfig()
# Inside an nn.Module's __init__; dim and hidden_dim are the layer sizes
self.w1 = Float8Linear(dim, hidden_dim, bias=False, config=config)
```

I get:

```
/root/prime/.venv/lib/python3.10/site-packages/torchao/float8/float8_linear.py:430: DeprecationWarning: torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:787.)
  autocast_dtype = torch.get_autocast_gpu_dtype()
```
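For reference, here is a minimal sketch of the non-deprecated call that the warning recommends. The helper name `current_cuda_autocast_dtype` is hypothetical. The `hasattr` fallback for older PyTorch releases is an assumption about which versions need support, and this is not necessarily how torchao resolves it.

```python
import torch


def current_cuda_autocast_dtype() -> torch.dtype:
    """Hypothetical helper: CUDA autocast dtype without the deprecation warning."""
    if hasattr(torch, "get_autocast_dtype"):
        # Replacement suggested by the DeprecationWarning above
        return torch.get_autocast_dtype("cuda")
    # Assumed fallback for older PyTorch releases that lack get_autocast_dtype
    return torch.get_autocast_gpu_dtype()
```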

More generally, the autocast logic around that point seems incomplete: it only queries the GPU autocast dtype and has no CPU implementation.
https://github.com/pytorch/ao/blob/main/torchao/float8/float8_linear.py#L312

@jcaip (Contributor)

jcaip commented Jan 8, 2025

Thanks for flagging! #1528 should fix the warning. I believe float8 is currently only supported on GPUs. cc @vkuzo @drisspg in case I'm mistaken.
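A caller could check up front whether float8 is likely usable before building `Float8Linear` layers. This sketch assumes the requirement is a CUDA device with compute capability 8.9 or higher (for example, Ada or Hopper GPUs). That threshold is an assumption, so confirm the exact hardware requirements in the torchao documentation. `float8_supported` is a hypothetical helper.

```python
import torch


def float8_supported() -> bool:
    """Hypothetical helper: rough check for float8-capable hardware."""
    if not torch.cuda.is_available():
        # No CUDA device: float8 is treated as GPU-only here
        return False
    # Assumption: float8 matmuls need compute capability 8.9+ (e.g. Ada, Hopper)
    return torch.cuda.get_device_capability() >= (8, 9)
```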

@jcaip jcaip added the triaged label Jan 8, 2025