[Non-performance-impacting update] Use Pytorch DDP in ppo_atari_multigpu #495

Open · wants to merge 3 commits into `master`
38 changes: 16 additions & 22 deletions cleanrl/ppo_atari_multigpu.py
@@ -15,6 +15,7 @@
import tyro
from rich.pretty import pprint
from torch.distributions.categorical import Categorical
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from stable_baselines3.common.atari_wrappers import ( # isort:skip
@@ -159,6 +160,10 @@ def get_action_and_value(self, x, action=None):
return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)


def unwrap_ddp(model) -> Agent:
return model.module if isinstance(model, DDP) else model


if __name__ == "__main__":
# torchrun --standalone --nnodes=1 --nproc_per_node=2 ppo_atari_multigpu.py
# taken from https://pytorch.org/docs/stable/elastic/run.html
@@ -208,7 +213,7 @@ def get_action_and_value(self, x, action=None):
args.seed += local_rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - local_rank)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

if len(args.device_ids) > 0:
@@ -228,7 +233,10 @@ def get_action_and_value(self, x, action=None):
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

agent = Agent(envs).to(device)
torch.manual_seed(args.seed)
if args.world_size > 1:
# DDP syncs gradients (after each backward step), weights are sync'd at DDP initialization
agent = DDP(agent)

optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

# ALGO Logic: Storage setup
@@ -260,7 +268,7 @@ def get_action_and_value(self, x, action=None):

# ALGO LOGIC: action logic
with torch.no_grad():
action, logprob, _, value = agent.get_action_and_value(next_obs)
action, logprob, _, value = unwrap_ddp(agent).get_action_and_value(next_obs)
values[step] = value.flatten()
actions[step] = action
logprobs[step] = logprob
@@ -282,11 +290,11 @@ def get_action_and_value(self, x, action=None):
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

print(
f"local_rank: {local_rank}, action.sum(): {action.sum()}, iteration: {iteration}, agent.actor.weight.sum(): {agent.actor.weight.sum()}"
f"local_rank: {local_rank}, action.sum(): {action.sum()}, iteration: {iteration}, agent.actor.weight.sum(): {unwrap_ddp(agent).actor.weight.sum()}"
)
# bootstrap value if not done
with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1)
next_value = unwrap_ddp(agent).get_value(next_obs).reshape(1, -1)
advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):
@@ -317,7 +325,9 @@ def get_action_and_value(self, x, action=None):
end = start + args.local_minibatch_size
mb_inds = b_inds[start:end]

_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
_, newlogprob, entropy, newvalue = unwrap_ddp(agent).get_action_and_value(
b_obs[mb_inds], b_actions.long()[mb_inds]
)
logratio = newlogprob - b_logprobs[mb_inds]
ratio = logratio.exp()

@@ -357,22 +367,6 @@ def get_action_and_value(self, x, action=None):
optimizer.zero_grad()
loss.backward()

if args.world_size > 1:
# batch allreduce ops: see https://github.com/entity-neural-network/incubator/pull/220
all_grads_list = []
for param in agent.parameters():
if param.grad is not None:
all_grads_list.append(param.grad.view(-1))
all_grads = torch.cat(all_grads_list)
dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
offset = 0
for param in agent.parameters():
if param.grad is not None:
param.grad.data.copy_(
all_grads[offset : offset + param.numel()].view_as(param.grad.data) / args.world_size
)
offset += param.numel()

nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()

52 changes: 11 additions & 41 deletions docs/rl-algorithms/ppo.md
@@ -909,66 +909,36 @@ See [related docs](/rl-algorithms/ppo/#explanation-of-the-logged-metrics) for `p

[ppo_atari_multigpu.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_multigpu.py) is based on `ppo_atari.py` (see its [related docs](/rl-algorithms/ppo/#implementation-details_1)).

We use [Pytorch's distributed API](https://pytorch.org/tutorials/intermediate/dist_tuto.html) to implement the data parallelism paradigm. The basic idea is that the user can spawn $N$ processes each running a copy of `ppo_atari.py`, holding a copy of the model, stepping the environments, and averaging their gradients together for the backward pass. Here are a few note-worthy implementation details.
We use the PyTorch [distributed API](https://pytorch.org/tutorials/intermediate/dist_tuto.html) and the [DistributedDataParallel (DDP) module](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) to implement data parallelism. The basic idea is that the user can spawn $N$ processes, each running a copy of `ppo_atari.py`, holding a copy of the model, stepping its own environments, and jointly learning a better policy. Here are a few noteworthy implementation details.

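Before going through them, here is a minimal sketch of the per-worker setup that `torchrun` drives. The `SimplePolicy` module, the backend selection, and the CPU fallback below are illustrative assumptions for a self-contained example, not the actual CleanRL `Agent` or launch configuration.

```python
# Launch with: torchrun --standalone --nnodes=1 --nproc_per_node=2 this_script.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class SimplePolicy(nn.Module):
    """Stand-in for the PPO agent; any nn.Module works the same way with DDP."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    def forward(self, x):
        return self.net(x)


if __name__ == "__main__":
    # torchrun exports these environment variables for every worker it spawns.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    # nccl is the usual choice when each worker owns a GPU; gloo works on CPU.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)

    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    policy = SimplePolicy().to(device)
    if world_size > 1:
        # Weights are broadcast from rank 0 when DDP is constructed; gradients
        # are averaged across workers during every backward pass afterwards.
        policy = DDP(policy, device_ids=[local_rank] if torch.cuda.is_available() else None)

    # ... each worker then steps its own environments and calls loss.backward() as usual.
    dist.destroy_process_group()
```
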
1. **Local versus global parameters**: All of the parameters in `ppo_atari.py` are global (such as the batch size), but in `ppo_atari_multigpu.py` we also have local parameters. Say we run `torchrun --standalone --nnodes=1 --nproc_per_node=2 cleanrl/ppo_atari_multigpu.py --env-id BreakoutNoFrameskip-v4 --local-num-envs=4`; here is how all multi-GPU-related parameters are adjusted (these derivations are also collected in a short sketch at the end of this section):
* **number of environments**: `num_envs = local_num_envs * world_size = 4 * 2 = 8`
* **batch size**: `local_batch_size = local_num_envs * num_steps = 4 * 128 = 512`, `batch_size = num_envs * num_steps = 8 * 128 = 1024`
* **minibatch size**: `local_minibatch_size = int(args.local_batch_size // args.num_minibatches) = 512 // 4 = 128`, `minibatch_size = int(args.batch_size // args.num_minibatches) = 1024 // 4 = 256`
* **number of updates**: `num_iterations = args.total_timesteps // args.batch_size = 10000000 // 1024 = 9765`
1. **Adjust seed per process**: we need be very careful with seeding: we could have used the exact same seed for each subprocess. To ensure this does not happen, we do the following
1. **Adjust seed per process**: we need to be very careful with seeding: we could have used the exact same seed for each subprocess. To ensure this does not happen, we do the following

```python hl_lines="2 5 16"
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += local_rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - local_rank)
torch.backends.cudnn.deterministic = args.torch_deterministic

# ...

envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

agent = Agent(envs).to(device)
torch.manual_seed(args.seed)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
```

Notice that we adjust the seed with `args.seed += local_rank` (line 2), where `local_rank` is the index of the subprocess. This ensures we seed packages and envs with uncorrelated seeds. However, we do need to use the same `torch` seed for all processes to initialize the same weights for the `agent` (line 5), after which we can use a different seed for `torch` (line 16).
1. **Efficient gradient averaging**: PyTorch recommends to average the gradient across the whole world via the following (see [docs](https://pytorch.org/tutorials/intermediate/dist_tuto.html#distributed-training))

```python
for param in agent.parameters():
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= world_size
```

However, [@cswinter](https://github.com/cswinter) introduces a more efficient gradient averaging scheme with proper batching (see :material-github: [entity-neural-network/incubator#220](https://github.com/entity-neural-network/incubator/pull/220)), which looks like:
1. **PyTorch DDP for weight and gradient synchronization**: We wrap the agent in PyTorch's [DistributedDataParallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) module, as shown below:

```python
all_grads_list = []
for param in agent.parameters():
if param.grad is not None:
all_grads_list.append(param.grad.view(-1))
all_grads = torch.cat(all_grads_list)
dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
offset = 0
for param in agent.parameters():
if param.grad is not None:
param.grad.data.copy_(
all_grads[offset : offset + param.numel()].view_as(param.grad.data) / world_size
)
offset += param.numel()
from torch.nn.parallel import DistributedDataParallel as DDP
...
agent = Agent(envs).to(device)
if args.world_size > 1:
# DDP syncs gradients (after each backward step), weights are sync'd at DDP initialization
agent = DDP(agent)
```

In our previous empirical testing (see :material-github: [vwxyzjn/cleanrl#162](https://github.com/vwxyzjn/cleanrl/pull/162#issuecomment-1107909696)), we have found [@cswinter](https://github.com/cswinter)'s implementation to be faster, hence we adopt it in our implementation.



`DDP` uses collective communications from the `torch.distributed` package to synchronize gradients across all processes after each backward pass. This means that each process holds its own copy of the model, but all processes work together to train the model as if it were on a single machine.
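
As a rough illustration of what that synchronization amounts to, the snippet below averages gradients with an explicit `all_reduce`, mirroring the hand-rolled code this PR removes. DDP performs an equivalent reduction automatically (bucketed and overlapped with the backward pass), so this helper is not needed in the actual script; `model` and `world_size` are just stand-in names.

```python
import torch.distributed as dist


def average_gradients_manually(model, world_size):
    # Roughly what DDP's gradient hooks do for you: sum each parameter's
    # gradient across all workers, then divide by the number of workers.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size


# Hypothetical usage inside the update loop (unnecessary once the agent is DDP-wrapped):
# loss.backward()
# average_gradients_manually(agent, world_size)
# optimizer.step()
```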

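Coming back to note 1 above, here is the promised short sketch that works through the local-versus-global parameter bookkeeping as plain arithmetic. The variable names mirror the script's argument names and the numbers match the `torchrun` command quoted in that note; this is illustrative arithmetic, not code from the script.

```python
local_num_envs = 4        # --local-num-envs=4
world_size = 2            # --nproc_per_node=2
num_steps = 128
num_minibatches = 4
total_timesteps = 10_000_000

num_envs = local_num_envs * world_size                       # 8
local_batch_size = local_num_envs * num_steps                # 512
batch_size = num_envs * num_steps                            # 1024
local_minibatch_size = local_batch_size // num_minibatches   # 128
minibatch_size = batch_size // num_minibatches               # 256
num_iterations = total_timesteps // batch_size               # 9765

print(num_envs, local_batch_size, batch_size, local_minibatch_size, minibatch_size, num_iterations)
```
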
### Experiment results
