Adds support for knowledge distillation #380
There are many forms of model training. One popular form is knowledge distillation, where a student model learns to match the output distributions of a teacher model. This commit introduces support for knowledge distillation in the training library.

This commit also exposes the `weight_decay` hyperparameter, which is often used to help deep learning models generalize.

Lastly, this commit changes usages of `torch.distributed` to just `dist`, as it is a common module used throughout the codebase.

Signed-off-by: Oleg S <[email protected]>
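As a rough illustration of what "learning the teacher's output distribution" involves, here is a minimal dependency-free sketch of a temperature-scaled distillation loss. It is not the code from this PR: the function names are hypothetical, the use of `alpha` to weight the soft (teacher-matching) term against the regular loss is an assumption about how the `alpha` field is meant to be used, and the `temperature ** 2` scaling is a common convention rather than something stated in the description. Plain Python is used instead of PyTorch only to keep the example self-contained.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def distillation_loss(student_logits, teacher_logits, hard_loss,
                      temperature=1.0, alpha=1.0):
    """Blend a soft teacher-matching loss with the regular (hard) loss.

    ASSUMPTION: alpha weights the soft term and (1 - alpha) the hard term.
    ASSUMPTION: the soft term is scaled by temperature ** 2.
    """
    if temperature <= 0.0:
        raise ValueError("temperature must be > 0")
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    soft_loss = kl_divergence(teacher_probs, student_probs) * temperature ** 2
    return alpha * soft_loss + (1.0 - alpha) * hard_loss
```

Under these assumptions, a student whose logits already match the teacher's contributes no soft loss, and `alpha=0.0` reduces the result to the regular loss alone.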
temperature: float = Field(1.0, gt=0.0)
alpha: float = Field(1.0, le=1.0, ge=0.0)
teacher_path: str
if possible would love to standardize on using pathlib.Path rather than str paths.
@JamesKunstle I see your point, would it make sense for it to be a path when it can also take on a HF reference? I understand that references can technically still be paths, but to a consumer reading it might sound like only local models are accepted. Would `str | Path` be satisfactory?
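To make the trade-off concrete, here is a small sketch of a helper that accepts either a string or a `Path` and reports whether the value looks like a local model directory or should be passed through as a hub-style reference. The helper name and the "local"/"hub" labels are hypothetical and not part of this PR; `Union[str, Path]` is used as the equivalent spelling of `str | Path` that also runs on older Python versions.

```python
from pathlib import Path
from typing import Tuple, Union


def describe_teacher_ref(teacher_path: Union[str, Path]) -> Tuple[str, str]:
    """Classify a teacher reference (hypothetical helper for illustration).

    Returns ("local", <absolute path>) when the value exists on disk,
    otherwise ("hub", <original string>) so it can be handed to the model
    loader unchanged.
    """
    candidate = Path(teacher_path)
    if candidate.exists():
        return ("local", str(candidate.resolve()))
    # Keep the caller's original string; do not normalise separators.
    return ("hub", str(teacher_path))
```

A type like this keeps `pathlib.Path` available for local checkpoints while still signalling to readers that a plain string reference is accepted.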
teacher_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16
).to(device)
model_dev = next(teacher_model.parameters()).device
If you're calling `.to(device)` just above, could you make a note of why you also need to confirm the device locale below?
Yes of course.
src/instructlab/training/config.py
weight_decay: float = Field(0.0, ge=0.0)

# settings for knowledge distillation
distillation_options: Optional[DistillationConfig] = None
I've seen that `Optional[DistillationConfig]` syntax is replaced by `DistillationConfig | None` in recent Pythonic parlance once the optional annotation was added to the language. This is a nit, not required to change.
Let's do the proposed method to be more consistent with how Python expects optionals in the future.
loss = None
if args.distill:
    # teacher_model should always be provided when `args.distill` is enabled
    if TYPE_CHECKING:
I think this is supposed to be a runtime check but TYPE_CHECKING is always False at runtime.
https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING
I think we should fail much earlier if distillation is set but no teacher_model is provided, like before we do any data preprocessing or fire up the GPUs.
Yeah it is, I believe I had errors with type-checking here though and it not knowing that `teacher_model` is properly set.
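For reference, the behaviour discussed in this thread can be shown with a short sketch. `typing.TYPE_CHECKING` is `False` when the program runs, so anything guarded by it is skipped at runtime. An explicit `None` check, by contrast, both runs at runtime and lets type checkers narrow an optional value. The `TeacherStub` class and `require_teacher` helper below are hypothetical stand-ins, not code from this PR.

```python
from typing import TYPE_CHECKING, Optional


def branch_taken() -> str:
    """Report which branch executes when the program actually runs."""
    if TYPE_CHECKING:
        # Static type checkers treat this branch as reachable,
        # but the interpreter never executes it.
        return "type-checking branch"
    return "runtime branch"


class TeacherStub:
    """Hypothetical stand-in for a loaded teacher model."""


def require_teacher(teacher_model: Optional[TeacherStub]) -> TeacherStub:
    """Runtime guard that also narrows Optional[TeacherStub] for type checkers."""
    if teacher_model is None:
        raise ValueError("teacher model cannot be None when `distill` is enabled")
    return teacher_model
```

After calling a guard like `require_teacher`, the returned value has a non-optional type, which may address the type-checking errors mentioned above without relying on `TYPE_CHECKING`.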
    ), "teacher model cannot be None when `distill` is enabled"

    with torch.no_grad():
        teacher_output: CausalLMOutput = teacher_model(
You turn off `requires_grad` on all the params in the teacher_model. You could just be doing this instead, I think this gives the same output.
So they're not fully the same. `requires_grad=False` ensures that a tensor never needs its gradient to be computed when `.backward()` is called at some point in the computation graph, and therefore doesn't need to store any additional data for it. Whereas `torch.no_grad` ensures that the tensor computations within the given context do not count towards the gradient calculation during backprop.

The reason we're doing both is so that:
- `requires_grad=False` --> The teacher model doesn't need to get updated, so we don't need to store any additional variables.
- `with torch.no_grad()` --> If any other tensors happen to participate in the computation for whatever reason, say for example someone updates this and includes them, their gradients are also not impacted by participating in this calculation.

Having this as an explicit context also allows us to communicate to other developers in the future that this is not intended to participate in backprop, which comes to us at no extra cost really.

Probably you can get away without using `torch.no_grad` here, but it's just a good practice to do both.
src/instructlab/training/main_ds.py
else:
    loss = output.loss

assert loss is not None, "loss cannot be equal to None!"
`assert`s are typically not preferred in comparison to runtime exceptions.
- assert loss is not None, "loss cannot be equal to None!"
+ if loss is None:
+     raise ValueError("loss was None during distillation training. Something unrecoverable went wrong.")
mostly because they can be removed with `-O` when the interpreter is invoked. But we want to check non-null all the time.
Sure I can change this. I was using them as scaffolding when writing this.
I'm gonna make it not be specific to distillation training though since it's more about how we branch out. I suspect as we add other loss calculations (contrastive loss, preference tuning loss, etc.), we will start out by setting it to `None` and having this final check to ensure it was set to something.
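The pattern described here (start with `None`, branch per training mode, then verify once at the end) might look roughly like the sketch below. The `select_loss` function and its parameters are hypothetical and only mirror the shape of the discussion. The explicit `if ... raise` is used because `assert` statements are stripped when Python runs with `-O`.

```python
from typing import Optional


def select_loss(distill: bool,
                distill_loss: Optional[float],
                default_loss: Optional[float]) -> float:
    """Pick the loss for the active training mode (hypothetical helper).

    Each branch is responsible for setting `loss`; the final check catches
    any branch that failed to do so, and it still runs under `python -O`.
    """
    loss = None
    if distill:
        loss = distill_loss
    else:
        loss = default_loss

    if loss is None:
        raise ValueError("loss was not set by any branch of the training step")
    return loss
```

New loss types (for example contrastive or preference-tuning losses) would add branches above the final check without needing their own guard.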
@@ -511,6 +609,9 @@ def main(args):
    # Third Party
    import yaml

    if args.distill and not args.teacher_model_name_or_path:
Yeah this early check seems right.
Sweet :party-cat:
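As an illustration of the fail-fast idea, the check can be run immediately after argument parsing, before any data preprocessing or GPU setup. The sketch below assumes `distill` and `teacher_model_name_or_path` are exposed as `argparse` flags, which the diff does not show; the parser and the `validate_distillation_args` helper are hypothetical.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Minimal parser with only the two options relevant to this check."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--distill", action="store_true")
    parser.add_argument("--teacher_model_name_or_path", type=str, default=None)
    return parser


def validate_distillation_args(args: argparse.Namespace) -> None:
    """Reject inconsistent distillation settings before any heavy work starts."""
    if args.distill and not args.teacher_model_name_or_path:
        raise ValueError(
            "--teacher_model_name_or_path is required when --distill is set"
        )
```

Calling `validate_distillation_args(build_parser().parse_args())` at the top of the entry point surfaces the misconfiguration within milliseconds rather than after model loading.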
Nearly ready to go. Just a couple of API questions.