From 5acbfa5bc42681ddf060829dc8cf389e3367e667 Mon Sep 17 00:00:00 2001 From: yisheng Date: Wed, 7 Aug 2024 22:39:02 -0700 Subject: [PATCH] add NaN check to support bfloat16 --- deepspeed/runtime/zero/stage3.py | 2 +- deepspeed/runtime/zero/stage_1_and_2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 3ac6987e9c22..2c6fd6d95308 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1963,7 +1963,7 @@ def _overflow_clean_up(self, prev_scale): def _overflow_check_and_loss_scale_update(self): # First compute norm for all group so we know if there is overflow - if self.dtype == torch.float16: + if self.dtype in [torch.float16, torch.bfloat16]: self.check_overflow() #loss scaling related computation diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 461ad719536a..0282c55f9da4 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1826,7 +1826,7 @@ def step(self, closure=None): see_memory_usage(f"In step before checking overflow") # First compute norm for all group so we know if there is overflow - if self.dtype == torch.float16: + if self.dtype in [torch.float16, torch.bfloat16]: self.check_overflow() prev_scale = self.loss_scale