Remove optimizer step on initialization (#5104)
All ZeRO 1/2/3 stages call the optimizer's `step()` during initialization. This increments the optimizer's internal step counter and makes subsequent parameter updates differ from normal PyTorch usage. This PR eliminates the `step()` call during initialization and lazily configures the related internal state (linking *hp_params* to the optimizer state) after the first real `step()` call.
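
To illustrate the problem and the fix (a minimal plain-PyTorch sketch, not DeepSpeed code): a zero-gradient `step()` at initialization already advances Adam's step counter and allocates optimizer state, whereas skipping it leaves the optimizer with no state until the first real `step()`, which is why linking *hp_params* to the optimizer state must be deferred until then.

```python
# Illustrative sketch only (plain PyTorch, not DeepSpeed code).
import torch

p = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.Adam([p], lr=0.1)

# Old behavior: an init-time step() with zero gradients still advances Adam's
# internal step counter, so later real updates see a shifted bias correction
# compared to normal PyTorch usage.
p.grad = torch.zeros_like(p)
opt.step()
print(opt.state[p]["step"])  # 1 (int or 0-dim tensor, depending on PyTorch version)

# New behavior: without the init-time step(), the optimizer holds no state at
# all until the first real step(), so hp_params can only be linked to the
# optimizer state after that call.
p2 = torch.nn.Parameter(torch.ones(4))
opt2 = torch.optim.Adam([p2], lr=0.1)
assert len(opt2.state) == 0
```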

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
tohtana and tjruwase authored Feb 11, 2024
1 parent 25a0204 commit 1817980
Showing 8 changed files with 141 additions and 95 deletions.
16 changes: 12 additions & 4 deletions deepspeed/runtime/bf16_optimizer.py
@@ -18,7 +18,7 @@
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process)

from deepspeed.utils import link_hp_params, fragment_address
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -165,6 +165,7 @@ def _setup_for_real_optimizer(self):

# Need optimizer states initialized before linking lp to optimizer state
self._link_all_hp_params()
self._hp_optimizer_states_linked = False
self._enable_universal_checkpoint()
self._param_slice_mappings = self._create_param_mapping()

@@ -199,9 +200,15 @@ def _link_all_hp_params(self):
param_group_index=i,
partition_start=partition_id * partition_size,
partition_size=partition_size,
partition_optimizer_state=self.optimizer.state[flat_hp_partition],
dp_group=self.real_dp_process_group[i])

def _lazy_init_hp_params_optimizer_state(self):
if not self._hp_optimizer_states_linked:
for i, _ in enumerate(self.optimizer.param_groups):
lazy_init_hp_params_optimizer_state(self.bf16_groups[i], self.fp32_groups_flat_partition[i],
self.optimizer.state)
self._hp_optimizer_states_linked = True

def initialize_optimizer_states(self):
"""Take an optimizer step with zero-valued gradients to allocate internal
optimizer state.
@@ -215,8 +222,6 @@ def initialize_optimizer_states(self):
param_partition.grad = grad_partition.to(
param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition

self.optimizer.step()

if self.grad_acc_dtype is not torch.float32:
for param_partition in self.fp32_groups_flat_partition:
param_partition.grad = None
@@ -263,6 +268,9 @@ def step(self, closure=None):

self.optimizer.step()

# We need to link optimizer state after the first step() call
self._lazy_init_hp_params_optimizer_state()

self.update_lp_params()

self.clear_hp_grads()
4 changes: 0 additions & 4 deletions deepspeed/runtime/zero/stage3.py
@@ -1015,10 +1015,6 @@ def initialize_optimizer_states(self):
else:
self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements)

# Initialize the optimizer states with the flattened fp32 partition.
if not is_adagrad:
self._optimizer_step(i)

if swappable_param_subgroup:
self._partitioned_params_swap_out(i)

70 changes: 50 additions & 20 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -28,7 +28,7 @@
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
from deepspeed.utils import link_hp_params
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.checkpoint import enable_universal_checkpoint

from deepspeed.utils import groups
@@ -88,6 +88,12 @@ def _get_padded_tensor(src_tensor, size):
return padded_tensor


def _pad_tensor_by_size(src_tensor, pad_size, dtype, device):
padded_tensor = torch.zeros(src_tensor.numel() + pad_size, dtype=dtype, device=device)
padded_tensor.data[:src_tensor.numel()].copy_(src_tensor.data)
return padded_tensor


class DeepSpeedZeroOptimizer(ZeROOptimizer):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
@@ -536,6 +542,8 @@ def __init__(self,
see_memory_usage(f"After initializing ZeRO optimizer", force=True)

self._link_all_hp_params()
self._hp_optimizer_states_linked = False

self._enable_universal_checkpoint()
self._param_slice_mappings = self._create_param_mapping()

@@ -578,9 +586,15 @@ def _link_all_hp_params(self):
param_group_index=i,
partition_start=partition_id * partition_size,
partition_size=partition_size,
partition_optimizer_state=self.optimizer.state[flat_hp_partition],
dp_group=self.real_dp_process_group[i])

def _lazy_init_hp_params_optimizer_state(self):
if not self._hp_optimizer_states_linked:
for i, _ in enumerate(self.optimizer.param_groups):
lazy_init_hp_params_optimizer_state(self.bit16_groups[i], self.single_partition_of_fp32_groups[i],
self.optimizer.state)
self._hp_optimizer_states_linked = True

def is_moe_group(self, group):
return 'moe' in group and group['moe']

@@ -664,8 +678,6 @@ def initialize_optimizer_states(self):
# which do lazy initialization of the state at the first call to step.
if isinstance(self.optimizer, torch.optim.Adagrad):
self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)
else:
self.optimizer.step()

if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
@@ -1793,6 +1805,9 @@ def _optimizer_step(self, group_no):
self.optimizer.step()
self.optimizer.param_groups = original_param_groups

# We need to link optimizer state after the first step() call
self._lazy_init_hp_params_optimizer_state()

def step(self, closure=None):
"""
Not supporting closure.
@@ -2208,19 +2223,39 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states, group
# Assume non-tensor states are not partitioned and equal across ranks, so return first one
return all_partition_states[0]

def _restore_base_optimizer_state(self, base_optimizer_group_states):
def _restore_step_from_elastic_checkpoint(self, all_state_dict):
assert BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]
assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
for sd in all_state_dict), "State dicts of all partitions must have the same step value"
return all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]

def _restore_base_optimizer_state(self, base_optimizer_group_states, base_optimizer_state_step, group_paddings):
if type(base_optimizer_group_states) == dict:
base_optimizer_group_states = base_optimizer_group_states['state']

saved_keys = base_optimizer_group_states[0].keys()

for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
if torch.is_tensor(self.optimizer.state[p][key]):
dst_tensor = self.optimizer.state[p][key]
src_tensor = _get_padded_tensor(saved, dst_tensor.numel())
self.optimizer.state[p][key].data.copy_(src_tensor.data)
padding = 0 if group_paddings is None else group_paddings[i]
for key in saved_keys:
saved = base_optimizer_group_states[i][key]

if torch.is_tensor(saved):
if key in self.optimizer.state[p]:
dst_tensor = self.optimizer.state[p][key]
src_tensor = _get_padded_tensor(saved, dst_tensor.numel())
self.optimizer.state[p][key].data.copy_(src_tensor.data)
else:
self.optimizer.state[p][key] = _pad_tensor_by_size(
saved, padding, torch.float32,
torch.device('cpu') if self.cpu_offload else self.device)
else:
self.optimizer.state[p][key] = saved

for param_group in self.optimizer.param_groups:
param_group['step'] = base_optimizer_state_step

def get_ep_ranks(self, rank=0, group_name=None):
from deepspeed.utils import groups
expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name)
@@ -2248,15 +2283,8 @@ def _restore_elastic_base_optimizer_state(self, all_state_dict):
partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i)
base_optimizer_group_states.append(partition_states)

self._restore_base_optimizer_state(base_optimizer_group_states)

# Restore step
if BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]:
assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
for sd in all_state_dict), "State dicts of all partitions must have the same step value"
loaded_param_groups_step = all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
for param_group in self.optimizer.param_groups:
param_group['step'] = loaded_param_groups_step
self._restore_base_optimizer_state(base_optimizer_group_states,
self._restore_step_from_elastic_checkpoint(all_state_dict), None)

def load_state_dict(self,
state_dict_list,
@@ -2368,7 +2396,9 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
self._restore_elastic_base_optimizer_state(state_dict_list)
else:
# loading an elastic checkpoint into rigid exec
self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE])
self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE],
current_rank_sd[BASE_OPTIMIZER_STATE_STEP],
current_rank_sd[GROUP_PADDINGS])

# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
2 changes: 1 addition & 1 deletion deepspeed/utils/__init__.py
@@ -17,6 +17,6 @@
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter
from .mixed_precision_linkage import link_hp_params
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd
10 changes: 8 additions & 2 deletions deepspeed/utils/mixed_precision_linkage.py
@@ -9,13 +9,19 @@


def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
param_group_index, partition_start, partition_size, partition_optimizer_state, dp_group):
param_group_index, partition_start, partition_size, dp_group):
local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group)

for lp_param, lp_start in local_lp_param_and_offset:
lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict,
offload_gradient_dict, use_offload, param_group_index,
partition_start, partition_size, partition_optimizer_state)
partition_start, partition_size)


def lazy_init_hp_params_optimizer_state(lp_param_list, flat_hp_partition, optimizer_state):
for lp in lp_param_list:
if lp._hp_mapping is not None:
lp._hp_mapping.set_optim_state_fragment(flat_hp_partition, optimizer_state[flat_hp_partition])


def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group):
17 changes: 9 additions & 8 deletions deepspeed/utils/tensor_fragment.py
@@ -21,11 +21,11 @@ class tensor_fragment:
lp_fragment_address: fragment_address
hp_fragment: torch.Tensor
hp_fragment_address: fragment_address
optim_fragment: Dict
gradient_dict: Dict
offload_gradient_dict: Dict
use_offload: bool
param_group_index: int
optim_fragment: Dict = None

def update_hp(self):
self.hp_fragment.data.copy_(self.lp_fragment.data)
@@ -39,6 +39,13 @@ def get_optim_state_fragment(self, key):
else:
raise ValueError(f'{key} not found in optimizer state fragment')

def set_optim_state_fragment(self, flat_hp_partition, optim_fragment):
self.optim_fragment = {
key: value.narrow(0, self.hp_fragment_address.start, self.hp_fragment_address.numel)
for key, value in optim_fragment.items()
if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
}

def get_hp_fragment_address(self):
return self.hp_fragment_address

@@ -255,7 +262,7 @@ def safe_set_local_fp32_param(param, value):


def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
param_group_index, partition_start, partition_size, optimizer_state_dict):
param_group_index, partition_start, partition_size):
lp_end = lp_param.numel() + lp_start
hp_start = partition_start
hp_end = partition_start + partition_size
@@ -268,11 +275,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict
fragment_numel = fragment_end - fragment_start
hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)
optim_fragment = {
key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel)
for key, value in optimizer_state_dict.items()
if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
}

lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)
Expand All @@ -281,7 +283,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict
lp_fragment_address=lp_frag_address,
hp_fragment=hp_fragment_tensor,
hp_fragment_address=hp_frag_address,
optim_fragment=optim_fragment,
gradient_dict=gradient_dict,
offload_gradient_dict=offload_gradient_dict,
use_offload=use_offload,
36 changes: 24 additions & 12 deletions tests/unit/runtime/zero/test_zero.py
@@ -1370,6 +1370,11 @@ class TestZeroAdamOptimizerStepCount(DistributedTest):
world_size = 1

def test(self, zero_stage):
# We verify three conditions:
# 1. global_steps starts at 0
# 2. All subgroups have the same step count
# 3. The global step count is the same as the step count of the first subgroup

# force all params to be partitioned by forcing threshold=0
config_dict = {
"train_micro_batch_size_per_gpu": 2,
@@ -1399,24 +1404,31 @@ def test(self, zero_stage):
model_parameters=model.parameters())
data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)

for i, batch in enumerate(data_loader):
assert model.global_steps == 0

for batch in data_loader:
loss = model(batch[0], batch[1])
model.backward(loss)

is_gradient_accumulation_boundary = model.is_gradient_accumulation_boundary()
model.step()

step_counts = []
if zero_stage == 3:
for sub_group_id, _ in enumerate(optimizer.fp16_groups):
fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id]
state = optimizer.optimizer.state[fp32_param]
step_counts.append(state["step"])
assert all(step == step_counts[0] for step in step_counts)
elif zero_stage == 1 or zero_stage == 2:
for param_group in optimizer.optimizer.param_groups:
for param in param_group["params"]:
state = optimizer.optimizer.state[param]
if is_gradient_accumulation_boundary:
step_counts = []

if zero_stage == 3:
for sub_group_id, _ in enumerate(optimizer.fp16_groups):
fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id]
state = optimizer.optimizer.state[fp32_param]
step_counts.append(state["step"])
elif zero_stage == 1 or zero_stage == 2:
for param_group in optimizer.optimizer.param_groups:
for param in param_group["params"]:
state = optimizer.optimizer.state[param]
step_counts.append(state["step"])

assert all(step == step_counts[0] for step in step_counts)
assert model.global_steps == step_counts[0]


@pytest.mark.parametrize("zero_stage", [1, 2, 3])
(1 changed file not shown.)