Tensor subclass methods for DTensor and FSDP2 #1664
Comments
I worked on some tensor subclasses that can work with DTensor + FSDP2. You might find this useful (more concise than NF4): https://github.com/pytorch/ao/blob/v0.8.0/torchao/prototype/quantized_training/int8.py
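For readers who don't want to open the link: below is a minimal sketch of the wrapper-subclass shape that file follows. It is not the actual int8.py code, and `Int8TrainingTensor` is an illustrative name; the key point is the `__tensor_flatten__` / `__tensor_unflatten__` pair, which is what lets torch.compile, DTensor, and FSDP2 see through to the inner tensors.

```python
# Hypothetical minimal wrapper subclass in the style of
# torchao/prototype/quantized_training/int8.py (names are illustrative).
import torch

class Int8TrainingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor):
        # The outer tensor only carries shape/dtype/device metadata;
        # the real storage lives in the inner tensors.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scale.dtype, device=int_data.device
        )

    def __init__(self, int_data: torch.Tensor, scale: torch.Tensor):
        self.int_data = int_data  # int8 payload
        self.scale = scale        # quantization scale

    def __tensor_flatten__(self):
        # Names of the inner-tensor attributes plus optional metadata:
        # this is the hook torch.compile / DTensor / FSDP2 use to
        # traverse and reconstruct the subclass.
        return ["int_data", "scale"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, metadata, outer_size, outer_stride):
        return cls(inner_tensors["int_data"], inner_tensors["scale"])

    def __repr__(self):
        return f"Int8TrainingTensor(shape={tuple(self.shape)})"
```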
@weifengpy Do we have any documentation on this?
@jeromeku I always prefer […]

For FSDP2 + NF4, as you mentioned, the required tensor ops are defined in TestFSDPOps (Line 310 in 8afd10e).

For FSDP2 + float8, it's simply dispatching every tensor op to the inner tensors with 3 lines of code: ao/torchao/float8/fsdp_utils.py, Lines 196 to 199 in 8afd10e (see the sketch below).

For DTensor, if you try DTensor(local_tensor=your_tensor_subclass) and call […]
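For reference, here is a hedged sketch of that "dispatch every op to the inner tensor" pattern. It is not the exact fsdp_utils.py code, and `InnerDispatchTensor` is an illustrative name; the three lines that matter are in `__torch_dispatch__`.

```python
# Sketch of the float8-style dispatch: unwrap subclass args, run the
# aten op on the inner tensors, re-wrap tensor outputs. Illustrative only.
import torch
from torch.utils._pytree import tree_map_only

class InnerDispatchTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner: torch.Tensor):
        self.inner = inner

    def __tensor_flatten__(self):
        return ["inner"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, metadata, outer_size, outer_stride):
        return cls(inner_tensors["inner"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # The essence of the float8 approach in ~3 lines: unwrap,
        # run the op on inner tensors, re-wrap any tensor outputs.
        out = func(
            *tree_map_only(cls, lambda t: t.inner, args),
            **tree_map_only(cls, lambda t: t.inner, kwargs or {}),
        )
        return tree_map_only(torch.Tensor, cls, out)

# FSDP2's shard/unshard ops (split, copy_, new_zeros, ...) then just work:
t = InnerDispatchTensor(torch.randn(8, 4))
a, b = t.split(4)  # each chunk comes back as an InnerDispatchTensor
```

With the flatten/unflatten hooks in place, wrapping such a subclass in DTensor should in principle be possible via DTensor.from_local(InnerDispatchTensor(x), mesh, placements), which is the local_tensor route mentioned above; I have not verified every op combination here.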
@drisspg I commented with examples above, but there's no official documentation yet. Agreed we should document this better; will think about it.
Is there a protocol / interface that a tensor subclass must implement in order to be used with DTensor primitives and for training with FSDP2?

I've been walking through NF4 as an example, as it covers both. However, the methods are scattered across __torch_function__ and __torch_dispatch__ (though the unit tests make it clear which ops are tested for FSDP).

Is there a cleaner / expected format for subclassing a tensor such that it works with DTensor collectives, FSDP2, and torch.compile?

@msaroufim @awgu @weifengpy @jerryzh168
P.S. FWIW, I also looked at the developer-guide tensor subclass example but found the abstractions a bit hard to follow; I'd personally prefer using torch-native functionality.
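To make the __torch_function__ vs __torch_dispatch__ split the question refers to concrete, here is a small runnable sketch (illustrative names, not the NF4 implementation): high-level ops like F.linear can be intercepted at the __torch_function__ layer, while the low-level aten ops that FSDP2 issues would land in a __torch_dispatch__ like the one sketched earlier.

```python
# Illustrative two-level interposition, not the NF4 implementation.
import torch
import torch.nn.functional as F

class TwoLevelTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner: torch.Tensor):
        self.inner = inner

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Python-level interception: this is the layer where NF4 hooks linear.
        if func is F.linear:
            input, weight = args[0], args[1]
            return F.linear(input, weight.inner, *args[2:], **kwargs)
        # Fall through to normal dispatch (and __torch_dispatch__, if defined).
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

w = TwoLevelTensor(torch.randn(6, 4))
y = F.linear(torch.randn(2, 4), w)  # routed through __torch_function__
print(y.shape)  # torch.Size([2, 6])
```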