diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 1dcff3f1c12f..eaab8d30b1b1 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -187,16 +187,18 @@ def reset_step(self) -> None: f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") if not self.is_complete_trace(): # not self.trace_complete: - # Make sure that recorded parameter and submodule orders are - # identical across ranks + # Make sure that recorded submodule orders are identical across ranks assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) - assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) - assert_ints_same_as_other_ranks( - [p.step_id_last_used_at for p in self.__param_order]) if self.is_record_trace(): # Successfully recorded a trace self.construct_parameter_trace_from_module_trace() + # Make sure that recorded parameter orders are identical across ranks + assert_ints_same_as_other_ranks( + [p.param.ds_id for p in self.__param_order]) + assert_ints_same_as_other_ranks( + [p.step_id_last_used_at for p in self.__param_order]) + self.__submodule_order = tuple(self.__submodule_order) # freeze self.__param_order = tuple(self.__param_order) # freeze self.__trace_mode = ZeRoTraceMode.COMPLETE