torch.get_autocast_gpu_dtype()
When I use:
```python
from torchao.float8.float8_linear import Float8Linear, Float8LinearConfig

config = Float8LinearConfig()
self.w1 = Float8Linear(dim, hidden_dim, bias=False, config=config)
```
I get:
```
/root/prime/.venv/lib/python3.10/site-packages/torchao/float8/float8_linear.py:430: DeprecationWarning: torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:787.)
  autocast_dtype = torch.get_autocast_gpu_dtype()
```
More generally, the autocast logic around that point seems incomplete; for example, it lacks a CPU implementation. https://github.com/pytorch/ao/blob/main/torchao/float8/float8_linear.py#L312
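The warning itself names the replacement: the device-generic `torch.get_autocast_dtype(device_type)`. A minimal sketch of a device-aware lookup that would avoid the warning and also cover CPU autocast is below. It assumes PyTorch 2.4 or later provides both `torch.get_autocast_dtype(device_type)` and `torch.is_autocast_enabled(device_type)`, and falls back to the older per-device getters otherwise. The helper name `active_autocast_dtype` is hypothetical and is not part of torchao.

```python
import torch


def active_autocast_dtype(device_type: str):
    """Return the autocast dtype for `device_type` if autocast is active there, else None.

    Hypothetical helper (not a torchao API). Assumes PyTorch >= 2.4 exposes the
    device-generic autocast getters; older releases use the legacy per-device ones.
    """
    if hasattr(torch, "get_autocast_dtype"):
        # Device-generic API suggested by the DeprecationWarning; works for "cuda" and "cpu".
        if not torch.is_autocast_enabled(device_type):
            return None
        return torch.get_autocast_dtype(device_type)

    # Legacy releases: only CUDA and CPU have dedicated getters.
    if device_type == "cuda":
        # Without arguments, is_autocast_enabled() reports the CUDA autocast state.
        return torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else None
    if device_type == "cpu":
        return torch.get_autocast_cpu_dtype() if torch.is_autocast_cpu_enabled() else None
    raise ValueError(f"unsupported device_type for the legacy autocast API: {device_type!r}")
```

For example, `active_autocast_dtype("cpu")` returns `torch.bfloat16` inside `with torch.autocast("cpu", dtype=torch.bfloat16):` and `None` outside it. Returning `None` when autocast is off lets a caller skip any dtype cast instead of applying a stale default.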
Thanks for flagging! #1528 should fix the warning. As far as I know, float8 is currently supported only on GPUs; cc @vkuzo @drisspg in case I'm mistaken.