[BUG]: Low_Level_Zero plugin crashes with LoRA #5909
Comments
Hey @Fallqs, thanks for reporting the bug; I will look into this. Btw, would it be possible to share the code you are using, or a minimal repro for the LoRA crash?
Sorry to bother you, but could you please describe it in more detail? I am using version 0.3.6 of ColossalAI, and I put the following code in the corresponding position according to your implementation, but it didn't work. Is it because I put it in the wrong place? I also want to use LoRA tuning. This is my code:
This is my issue:
Please share a minimal script to reproduce the error. Your code is wrong, as `_run_reduction` reduces grads for all bucketed parameters.
Thank you for your reply. Regarding the above issue, I found that my code was added in the wrong location; the original poster was referring to line 808 of version 0.4.1. Now I have a new question:
When I tried to use `p.grad`, an error occurred. After checking, I found that after using ColossalAI I cannot directly access the gradient via `p.grad`. So the question is: how can we obtain gradient information?

[rank0]: Traceback (most recent call last):

This is the website I searched for:

Thank you again for your enthusiastic response.
You can get the grads by calling `optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(p))` on the zero optimizer.
I have read the above code before, but my code does not involve a zero_optimizer. Can you be more specific about how to implement it?
Does your training code involve an optimizer? That's what you're looking for.
Sorry to bother you again; let me refine my question. The following is a minimal reproduction of my problem, although it involves several opensora methods that need to be imported. I used the call `optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(p))` mentioned above to try to access the parameters' gradients, but I did not get any values. The optimizer I use here is HybridAdam, and I use Booster, which is not used in the link you gave. My question is: how can I get the gradient with the code above?
This is my run command:

I also tried some code that can successfully obtain the gradient, as follows:
I don't know what the difference between the two is. I think it is that one uses `booster.backward(loss, optimizer)` and the other uses `loss.backward()` to backpropagate the gradient. Is it possible that I can't get the gradient when I use the booster, or is there something wrong with my code?
Hey @281LinChenjian, regarding the problem you've got:

Code snippet 1: here, after the optimizer updates the params, it clears the stored grads, so the grad store is empty by the time you query it.

Code snippet 2: don't call …
Since there is no universal API for accessing gradients, it can be a bit tricky and confusing, so feel free to ask here or open another issue if you still have problems :)
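To make the ordering concrete, here is a rough sketch (names like `model`, `criterion`, and `batch` are placeholders; it assumes the Booster + HybridAdam + LowLevelZero setup discussed in this thread and gradient group id 0):

```python
# Sketch only: shows when the partitioned grads are available, based on the
# explanation above (the grad store is cleared once the optimizer has stepped).
output = model(batch)                      # placeholder forward pass
loss = criterion(output)
booster.backward(loss, optimizer)          # booster.backward, as in your first snippet

# Read the shards *before* optimizer.step(); group id 0 is assumed here.
for p in model.parameters():
    if p.requires_grad:
        shard = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(p))
        print(p.shape, shard)

optimizer.step()                           # after this, the stored grads are cleared
optimizer.zero_grad()
```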
Thank you for your generous help. I have thoroughly understood how to use this method. Thank you again for your patient answers!!!
Sorry to bother you again. I found that when training with multiple graphics cards, the shape obtained by `get_partitioned_gradients_by_param_id` no longer matches the parameter's shape.

This is the error when training with two graphics cards:

This is the error when training with four graphics cards:

Interestingly, 3456×1152 = 1990656×2, and 3456×1152 = 995328×4.
ZeRO splits gradients evenly across devices.
Is there any way to put them back together, or to get the corresponding gradients from the different graphics cards?
@281LinChenjian I guess you'll have to manually do an all-gather across ranks. For your reference:
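A rough sketch of what that might look like (assumptions: each rank holds an equal flat shard with no padding, the grad-store call behaves as discussed above, and group id 0 is used; the exact return format may differ between ColossalAI versions):

```python
import torch
import torch.distributed as dist

def gather_full_grad(optimizer, param, group_id=0):
    """Illustrative only: reassemble a ZeRO-partitioned gradient on every rank."""
    shard = optimizer._grad_store.get_partitioned_gradients_by_param_id(group_id, id(param))
    # The call may return a list of tensors or a single flat tensor; flatten either way.
    if isinstance(shard, (list, tuple)):
        shard = torch.cat([g.flatten() for g in shard])
    else:
        shard = shard.flatten()
    # Collect every rank's shard and stitch them back into the full parameter shape,
    # e.g. 2 x 1990656 or 4 x 995328 pieces -> the 3456 x 1152 gradient above.
    pieces = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
    dist.all_gather(pieces, shard)
    return torch.cat(pieces).view(param.shape)
```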
Is there an existing issue for this bug?
🐛 Describe the bug
Line 808 of `zero/low_level/low_level_optim.py` assumes that every single parameter in `model.parameters()` is trainable. However, this is not true for LoRA tuning, which results in training crashes. To solve this issue, you may just add a shortcut below this `for` loop (see the sketch below):
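A minimal sketch of the kind of guard meant here (illustrative only, not the exact patch; it assumes the loop iterates over `model.parameters()` as described above):

```python
import torch.nn as nn

# Stand-in model: under LoRA tuning, most base parameters are frozen.
model = nn.Linear(1152, 3456)
model.weight.requires_grad_(False)

for param in model.parameters():
    if not param.requires_grad:
        # The suggested shortcut: skip frozen parameters instead of
        # assuming every parameter has a gradient to handle.
        continue
    # ... the optimizer's original per-parameter handling goes here ...
    print("trainable:", param.shape)
```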
Environment
CUDA 12.1
PyTorch 2.1.2
ColossalAI 0.4.0 [This BUG is not observed in 0.3.5]