NF4Tensor and DDP #1665
Comments
cc @drisspg @weifengpy any guidance on this?
Curious about the use case here. Is it finetuning/QLoRA on a Llama/transformer-like model? We didn't support DDP + NF4 because Llama/QLoRA models are always parallelized with FSDP first. Would you consider FSDP? For DDP, we may need to design a tensor subclass extension point, if the use case is motivating.
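For reference, a minimal sketch of what the suggested FSDP route could look like, assuming PyTorch's per-module `fully_shard` API (FSDP2), for which NF4Tensor implements all-gather extensions. The model constructor and layer names below are illustrative, not from this thread:

```python
# Sketch: shard a QLoRA-style model with FSDP2 instead of DDP.
# `build_qlora_model` is a hypothetical constructor whose frozen base
# weights are NF4Tensor and whose LoRA adapters are ordinary tensors.
from torch.distributed._composable.fsdp import fully_shard  # torch.distributed.fsdp.fully_shard in newer releases

model = build_qlora_model()
for block in model.layers:   # shard each transformer block first
    fully_shard(block)
fully_shard(model)           # then the root module
```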
Yes exactly, using QLoRA finetuning with DDP.
See this PR which implements the necessary function to run DDP with NF4Tensor. Note that a much simpler fix is to have DDP ignore the NF4Tensor parameters. E.g.:

```python
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.dtypes.nf4tensor import NF4Tensor


def get_params_to_ignore(model):
    # Collect the names of all NF4-quantized parameters.
    params_to_ignore = []
    for name, param in model.named_parameters():
        if isinstance(param, NF4Tensor):
            params_to_ignore.append(name)
    return params_to_ignore


model = QLoRAModel(...)  # your model with NF4 base weights
params_to_ignore = get_params_to_ignore(model)
DDP._set_params_and_buffers_to_ignore_for_model(model, params_to_ignore)
```

This will exclude the NF4Tensor weights from DDP's parameter broadcast and gradient synchronization.
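For completeness, a sketch of how the subsequent DDP wrapping might look; the optimizer setup and device placement are illustrative assumptions, not part of the original comment:

```python
import torch

# After registering the NF4 weights as ignored, wrap the model as usual.
# Only the remaining trainable (non-NF4) parameters, e.g. the LoRA adapters,
# are broadcast at construction and all-reduced during backward.
ddp_model = DDP(model.cuda(), device_ids=[torch.cuda.current_device()])
optimizer = torch.optim.AdamW(
    (p for p in ddp_model.parameters() if p.requires_grad), lr=1e-4
)
```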
@jeromeku awesome, thanks! And thanks for the pointer to `_set_params_and_buffers_to_ignore_for_model`.
I am trying to use `NF4Tensor` weights in my model and wrap it with `DistributedDataParallel`, but get the following error.

To replicate: `torchrun --nproc_per_node=2 script.py`

```
NotImplementedError: NF4Tensor dispatch: attempting to run c10d.broadcast_.default, this is not supported
```
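A minimal `script.py` along these lines reproduces the failure. This is a hypothetical sketch, not the reporter's actual script; `to_nf4` from torchao and the single linear layer are illustrative assumptions:

```python
# Hypothetical minimal reproduction: DDP broadcasts parameters from rank 0
# during construction, which dispatches c10d.broadcast_ on the NF4Tensor
# and raises NotImplementedError.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.dtypes.nf4tensor import to_nf4


def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    linear = torch.nn.Linear(64, 64, bias=False).cuda()
    # Replace the fp32 weight with a frozen NF4-quantized copy.
    linear.weight = torch.nn.Parameter(
        to_nf4(linear.weight.detach()), requires_grad=False
    )

    # Fails here with the c10d.broadcast_ dispatch error.
    model = DDP(linear, device_ids=[torch.cuda.current_device()])


if __name__ == "__main__":
    main()
```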
Is there some way around this issue?