Tensor subclass methods for DTensor and FSDP2 #1664

Open · jeromeku opened this issue Feb 5, 2025 · 3 comments

Labels: question (Further information is requested)

Comments

jeromeku (Collaborator) commented Feb 5, 2025

Is there a protocol / interface that a tensor subclass must implement in order to be used with DTensor primitives and for training with FSDP2?

I've been walking through NF4 as an example since it covers both. However, the relevant methods are scattered across __torch_function__ and __torch_dispatch__ (though the unit tests make it clear which ops are exercised for FSDP).

Is there a cleaner / expected format for subclassing a tensor such that

  • it can be used with DTensor collectives and FSDP2, and
  • composed with subclass-specific overrides for streamlined use with torch.compile?

@msaroufim @awgu @weifengpy @jerryzh168


P.S. FWIW, I also looked at the developer-guide tensor subclass example but found the abstractions a bit hard to follow; I'd personally prefer to use torch-native functionality.

gau-nernst (Collaborator) commented:
I worked on some tensor subclasses that work with DTensor + FSDP2. You might find this useful (it's more concise than NF4):

https://github.com/pytorch/ao/blob/v0.8.0/torchao/prototype/quantized_training/int8.py

drisspg added the question (Further information is requested) label on Feb 5, 2025
drisspg (Contributor) commented Feb 5, 2025

@weifengpy Do we have any documentation on this?

weifengpy (Contributor) commented Feb 5, 2025

> However, the methods are scattered across __torch_function__ and __torch_dispatch__

@jeromeku I always prefer __torch_dispatch__, and it should be enough. NF4 has legacy __torch_dispatch__ implementations for single device; when I extended NF4 to FSDP2, I had to use __torch_function__ to stay backward compatible.
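
As a rough sketch of what the __torch_dispatch__-centric route looks like (a hypothetical wrapper subclass, not the NF4 or float8 code): __torch_function__ fires at the Python API level before autograd, __torch_dispatch__ sees the decomposed aten ops afterwards, and for a wrapper subclass the dispatch handler can simply unwrap to the inner tensor, run the op, and re-wrap:

```python
import torch
from torch.utils._pytree import tree_map_only

class InnerWrappedTensor(torch.Tensor):
    """Hypothetical wrapper subclass: all data lives in `self.inner`."""

    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        # The outer tensor mirrors the inner tensor's metadata but holds no storage.
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner: torch.Tensor):
        self.inner = inner

    def __repr__(self):
        return f"InnerWrappedTensor({self.inner!r})"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Fires at the Python API level (e.g. F.linear), before autograd.
        # Just fall through so everything reaches __torch_dispatch__ below.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **(kwargs or {}))

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Fires once per aten op (e.g. aten.mm.default), below autograd.
        # Unwrap every InnerWrappedTensor argument, run the op on plain
        # tensors, then re-wrap tensor outputs to keep the subclass alive.
        args, kwargs = tree_map_only(cls, lambda t: t.inner, (args, kwargs or {}))
        out = func(*args, **kwargs)
        return tree_map_only(torch.Tensor, lambda t: cls(t), out)
```

The float8 snippet later in this comment is this same unwrap-and-redispatch pattern.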

> Is there a cleaner / expected format for subclassing a tensor such that it can be used with DTensor collectives and FSDP2?

For FSDP2 + NF4, as you mentioned, the required tensor ops are defined in TestFSDPOps (class TestFSDPOps(TestCase)). I implemented those tensor ops one by one because NF4 contains many attributes (scalars or small tensors) that are not shardable, so I had to manually define the behavior for each op.
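
A condensed sketch of that "one op at a time" pattern (hypothetical names, not the actual NF4 code): keep a table from aten op to handler, route __torch_dispatch__ through it, and make each handler rebuild the subclass while carrying the non-shardable attributes along explicitly:

```python
import torch

OPS_TABLE = {}  # aten overload -> handler

def implements(aten_ops):
    # Register a handler for specific aten ops.
    def decorator(fn):
        for op in aten_ops:
            OPS_TABLE[op] = fn
        return fn
    return decorator

class QuantizedParam(torch.Tensor):
    """Toy subclass with a non-shardable attribute (`scale`)."""

    @staticmethod
    def __new__(cls, codes: torch.Tensor, scale: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, codes.shape, dtype=torch.float32, device=codes.device
        )

    def __init__(self, codes: torch.Tensor, scale: torch.Tensor):
        self.codes = codes  # shardable quantized payload
        self.scale = scale  # per-tensor scale: replicated, not sharded

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func in OPS_TABLE:
            return OPS_TABLE[func](func, args, kwargs or {})
        raise NotImplementedError(f"{cls.__name__} has no handler for {func}")

# FSDP2's sharding path exercises ops such as detach / split / new_zeros /
# copy_ (see TestFSDPOps for the list actually tested); each handler must
# decide what happens to `scale`.
@implements([torch.ops.aten.detach.default])
def _detach(func, args, kwargs):
    src = args[0]
    return QuantizedParam(src.codes.detach(), src.scale)
```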

For FSDP2 + float8, it's simply dispatching every tensor op to inner tensors with 3 lines of code:

```python
args, kwargs = pytree.tree_map_only(
    WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
)
out = func(*args, **kwargs)
```

For DTensor, if you try DTensor(local_tensor=your_tensor_subclass) and call .full_tensor(), you should see an unimplemented tensor op related to all-gather. I guess this is required for the state dict. I don't have an example because we issue collectives on the local_tensors directly in torchtune:

https://github.com/pytorch/torchtune/blob/9475b5adab6aa2746b08c73059ca9af9f791559a/torchtune/training/_distributed.py#L259
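
A minimal way to reproduce that experiment, using the hypothetical InnerWrappedTensor from the earlier sketch (this assumes a recent PyTorch with the public torch.distributed.tensor API and a torchrun launch):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Run under torchrun so rank / world size are set; use "nccl" + "cuda" on GPUs.
dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Each rank wraps its local shard in the subclass, then in DTensor.
local = InnerWrappedTensor(torch.randn(128, 64))
dt = DTensor.from_local(local, mesh, [Shard(0)])

# .full_tensor() all-gathers the shards; any aten op that the gather path
# hits but the subclass does not implement surfaces here.
full = dt.full_tensor()
```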

> Do we have any documentation on this?

@drisspg I commented with examples above, but there is no official document yet. Agreed, we should document this better; I will think about it.
