diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 796957a4c6e5..dee32ab8d54e 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1997,7 +1997,7 @@ def _overflow_clean_up(self, prev_scale): def _overflow_check_and_loss_scale_update(self): # First compute norm for all group so we know if there is overflow - if self.dtype == torch.float16: + if self.dtype in [torch.float16, torch.bfloat16]: self.check_overflow() #loss scaling related computation diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index df7a2f83e3bc..da4f86806a33 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1828,7 +1828,7 @@ def step(self, closure=None): see_memory_usage(f"In step before checking overflow") # First compute norm for all group so we know if there is overflow - if self.dtype == torch.float16: + if self.dtype in [torch.float16, torch.bfloat16]: self.check_overflow() prev_scale = self.loss_scale