From aeda7f9f8c5bdaee318e7b6094279484d41c659c Mon Sep 17 00:00:00 2001
From: AGUL
Date: Thu, 1 Dec 2022 02:10:50 +0800
Subject: [PATCH] Fix invalid check of recorded parameter orders in zero stage3. (#2550)

Co-authored-by: Olatunji Ruwase
---
 .../runtime/zero/partitioned_param_coordinator.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 1dcff3f1c12f..eaab8d30b1b1 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -187,16 +187,18 @@ def reset_step(self) -> None:
                 f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}")
 
         if not self.is_complete_trace():  # not self.trace_complete:
-            # Make sure that recorded parameter and submodule orders are
-            # identical across ranks
+            # Make sure that recorded submodule orders are identical across ranks
             assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])
-            assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order])
-            assert_ints_same_as_other_ranks(
-                [p.step_id_last_used_at for p in self.__param_order])
 
             if self.is_record_trace():
                 # Successfully recorded a trace
                 self.construct_parameter_trace_from_module_trace()
+                # Make sure that recorded parameter orders are identical across ranks
+                assert_ints_same_as_other_ranks(
+                    [p.param.ds_id for p in self.__param_order])
+                assert_ints_same_as_other_ranks(
+                    [p.step_id_last_used_at for p in self.__param_order])
+
                 self.__submodule_order = tuple(self.__submodule_order)  # freeze
                 self.__param_order = tuple(self.__param_order)  # freeze
                 self.__trace_mode = ZeRoTraceMode.COMPLETE