
NF4Tensor and DDP #1665

Open · psinger opened this issue Feb 5, 2025 · 6 comments · May be fixed by #1684
Labels: question (Further information is requested)

Comments

psinger commented Feb 5, 2025

I am trying to use NF4Tensor weights in my model and wrap it with DistributedDataParallel, but I get the following error:

[rank0]:     model = DistributedDataParallel(
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 837, in __init__
[rank0]:     _sync_module_states(
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/distributed/utils.py", line 313, in _sync_module_states
[rank0]:     _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/distributed/utils.py", line 324, in _sync_params_and_buffers
[rank0]:     dist._broadcast_coalesced(
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torchao/dtypes/nf4tensor.py", line 834, in __torch_dispatch__
[rank0]:     raise NotImplementedError(
[rank0]: NotImplementedError: NF4Tensor dispatch: attempting to run aten.cat.default, this is not supported

To replicate:

import os

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torchao.dtypes.nf4tensor import to_nf4


class NF4(nn.Module):
    def __init__(self, device=None):
        super().__init__()
        self.linear = nn.Linear(512, 512, bias=False, device=device)
        # Swap the fp32 weight for an NF4-quantized parameter
        self.linear.weight = nn.Parameter(to_nf4(self.linear.weight))


if __name__ == "__main__":
    _local_rank = int(os.getenv("LOCAL_RANK", "0"))
    _device = f"cuda:{_local_rank}"

    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        device_id=torch.device(_local_rank),
    )

    model = NF4(_device)

    # DDP construction fails while broadcasting the NF4 weight
    model = DistributedDataParallel(model)

Run with:

torchrun --nproc_per_node=2 script.py

which fails with:

NotImplementedError: NF4Tensor dispatch: attempting to run c10d.broadcast_.default, this is not supported

Is there some way around this issue?

supriyar (Contributor) commented Feb 5, 2025

cc @drisspg @weifengpy any guidance on this?

supriyar added the "question" label Feb 5, 2025
weifengpy (Contributor) commented:

> cc @drisspg @weifengpy any guidance on this?

Curious about the use case here - is it finetuning/QLoRA on a llama/transformer-like model? We didn't support DDP + NF4 because llama/QLoRA is always parallelized with FSDP first. Would you consider FSDP? For DDP, we may need to design a tensor subclass extension point if the use case is motivating.
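
For context, a rough sketch of what the FSDP path suggested here could look like (illustrative, not from this thread): it leans on NF4Tensor's FSDP2 support in torchao and assumes the process group is initialized as in the reproducer above. On older PyTorch releases the import path is torch.distributed._composable.fsdp.

from torch.distributed.fsdp import fully_shard  # FSDP2 composable API

model = NF4(_device)  # reproducer module from the original post
# Shard the module instead of replicating it; keep params gathered after forward
fully_shard(model, reshard_after_forward=False)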

psinger (Author) commented Feb 6, 2025

Yes exactly, using QLoRA finetuning with DDP.
I think FSDP is an option, but even with reshard_after_forward=False set, in my experience DDP is more efficient.

jeromeku linked a pull request (#1684) Feb 9, 2025 that will close this issue
jeromeku (Collaborator) commented Feb 9, 2025

@psinger @weifengpy

See PR #1684, which implements the necessary function to run DDP with NF4Tensor.

Note that a much simpler fix is to have DDP ignore the frozen NF4 (QLoRA base) params, since these params don't require grads. E.g.:

from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.dtypes.nf4tensor import NF4Tensor


def get_params_to_ignore(model):
    # Collect the names of all NF4-quantized parameters
    params_to_ignore = []
    for name, param in model.named_parameters():
        if isinstance(param, NF4Tensor):
            params_to_ignore.append(name)
    return params_to_ignore


model = QLoRAModel(...)
params_to_ignore = get_params_to_ignore(model)
# Tell DDP not to broadcast/sync these parameters
DDP._set_params_and_buffers_to_ignore_for_model(model, params_to_ignore)

This will exclude the NF4 params from being synced during DDP construction, which is what was triggering the error you were seeing.
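
For illustration, a more self-contained sketch of this workaround on a toy QLoRA-style layer (QLoRALinear, the LoRA rank, and the shapes are made up for the example; it assumes torch.distributed.init_process_group has already been called as in the reproducer):

import os

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4, to_nf4


class QLoRALinear(nn.Module):
    """Toy QLoRA-style layer: frozen NF4 base weight plus trainable LoRA factors."""

    def __init__(self, in_features, out_features, rank=8, device=None):
        super().__init__()
        base = torch.empty(out_features, in_features, device=device)
        nn.init.kaiming_uniform_(base)
        # Frozen, NF4-quantized base weight - the param DDP must not broadcast
        self.weight = nn.Parameter(to_nf4(base), requires_grad=False)
        # Trainable LoRA factors - the params DDP actually needs to sync
        self.lora_a = nn.Parameter(0.01 * torch.randn(rank, in_features, device=device))
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank, device=device))

    def forward(self, x):
        return linear_nf4(x, self.weight) + x @ self.lora_a.T @ self.lora_b.T


_local_rank = int(os.getenv("LOCAL_RANK", "0"))
model = QLoRALinear(512, 512, device=f"cuda:{_local_rank}")

ignored = [n for n, p in model.named_parameters() if isinstance(p, NF4Tensor)]
DDP._set_params_and_buffers_to_ignore_for_model(model, ignored)
model = DDP(model, device_ids=[_local_rank])  # constructs without broadcasting the NF4 weight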

psinger (Author) commented Feb 9, 2025

@jeromeku awesome, thanks!
In the meantime I implemented it via BNB Params4bit, so this will be a nice cross-check when I also try it with NF4Tensor.

And thanks for the pointer to _set_params_and_buffers_to_ignore_for_model - I tried to ignore those params myself but did not realize that function was available to do it directly.
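
For reference, a rough sketch of the bitsandbytes route mentioned here (illustrative, not the author's actual code): Params4bit quantizes the weight to nf4 lazily when the parameter is moved to a CUDA device, with blocksize=64 to match the NF4 setup.

import bitsandbytes as bnb
from torch import nn

linear = nn.Linear(512, 512, bias=False)
# Moving the Params4bit parameter to a CUDA device triggers the 4-bit quantization
linear.weight = bnb.nn.Params4bit(
    linear.weight.data,
    requires_grad=False,
    quant_type="nf4",
    blocksize=64,
).to("cuda")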

psinger (Author) commented Feb 10, 2025

@jeromeku thanks, both your solutions work, i.e. the PR and having DDP ignore those params.

However, I am seeing nearly double the runtime vs. BNB Params4bit - is this expected, or am I missing something? Both use nf4 and block_size=64.
#1686
