diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e15cc49339ff..5c1202ba06ae 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1911,9 +1911,6 @@ def print_forward_breakdown(self, fwd_time): @instrument_w_nvtx def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): - assert not (self.bfloat16_enabled() and self.pipeline_parallelism), \ - f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled' - # Pass (PP) gas boundary flag to optimizer (required for zero) self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well @@ -1926,7 +1923,11 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): self.optimizer, 'reduce_gradients'): self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism) else: - self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) + grads = None + if hasattr(self.optimizer, "get_grads_for_reduction"): + # This is currently for BF16 optimizer + grads = self.optimizer.get_grads_for_reduction() + self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size) @instrument_w_nvtx def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):