🚀 Feature
When sampling from a RolloutBuffer, we return RolloutBufferSamples containing tensors of observations, actions, etc.
stable-baselines3/stable_baselines3/common/buffers.py, lines 473 to 479 in 69b94dd
It would be nice if RolloutBufferSamples could also contain a batch of next observations (alongside a mask that, for each observation, tells us whether that observation has a successor).
Motivation
I'm implementing an RL pipeline in which I extend PPO with a custom loss. For this custom loss, I need access to (observation, next observation) pairs.
In the PPO implementation (stable-baselines3/stable_baselines3/ppo/ppo.py, lines 192 to 197 in 69b94dd), each batch of rollout data over which we compute the PPO loss is a RolloutBufferSamples instance. Because each batch consists of a random subset of observations from the RolloutBuffer, we do not have enough information to recover the next observation for each observation in the batch.
Pitch
I have already implemented this feature and submitted it as a PR [to be linked after submission].
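The PR is not linked here, so the following is only a rough sketch of what such a feature might compute, not the submitted implementation. It assumes the buffer keeps observations and episode_starts indexed as [step][env], with episode_starts[t][e] marking that the observation at step t begins a new episode. The function name next_observations_with_mask is hypothetical, and plain Python lists stand in for the NumPy arrays the real buffer uses.

```python
from typing import List, Sequence, Tuple

Obs = Sequence[float]


def next_observations_with_mask(
    observations: List[List[Obs]],
    episode_starts: List[List[bool]],
) -> Tuple[List[List[Obs]], List[List[bool]]]:
    """Pair every stored observation with its successor, if one is stored.

    Assumption: both inputs are indexed as [step][env], and
    episode_starts[t][e] is True when observations[t][e] is the first
    observation of a new episode.
    """
    n_steps = len(observations)
    n_envs = len(observations[0]) if n_steps else 0

    # Where no successor exists, pad with the observation itself; the mask
    # tells the caller to ignore that entry.
    next_obs = [[observations[t][e] for e in range(n_envs)] for t in range(n_steps)]
    has_next = [[False] * n_envs for _ in range(n_steps)]

    # The final step never has a stored successor, so stop one step early.
    for t in range(n_steps - 1):
        for e in range(n_envs):
            # A successor exists only if step t + 1 continues the same episode.
            if not episode_starts[t + 1][e]:
                next_obs[t][e] = observations[t + 1][e]
                has_next[t][e] = True
    return next_obs, has_next


# Toy rollout: 4 steps, 1 environment, a new episode starting at step 2.
observations = [[[0.0]], [[1.0]], [[2.0]], [[3.0]]]
episode_starts = [[True], [False], [True], [False]]
next_obs, has_next = next_observations_with_mask(observations, episode_starts)
```

Because the pairing happens before the buffer is shuffled into minibatches, the next observations and the mask can be sampled with the same indices as every other field. Padding masked-out entries with the observation itself is just one choice; any placeholder works as long as the loss respects the mask.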
Alternatives
Alternatively, we could return the indices of the sampled elements with respect to the original buffer. While this may allow for more general buffer manipulation, this feels less pleasant to use.
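To illustrate this alternative, here is a minimal sketch of how a caller might turn sampled indices into (observation, next observation) pairs. It assumes the flattened buffer lays out each environment's steps contiguously, so that flat index i corresponds to env i // n_steps and step i % n_steps; that layout, the helper name successor_index, and the toy data are all assumptions made for illustration.

```python
import random
from typing import List, Optional


def successor_index(
    i: int, n_steps: int, episode_starts_flat: List[bool]
) -> Optional[int]:
    """Return the flat index of the successor of entry i, or None.

    Assumption: entries are grouped per environment, n_steps entries each,
    and episode_starts_flat[j] is True when entry j begins a new episode.
    """
    j = i + 1
    if j % n_steps == 0:
        # Entry i was the last stored step of its environment.
        return None
    if episode_starts_flat[j]:
        # Entry j belongs to a new episode, so it is not a successor of i.
        return None
    return j


# Toy flattened buffer: 2 environments with 3 steps each.
n_steps = 3
obs_flat = [0.0, 1.0, 2.0, 10.0, 11.0, 12.0]
episode_starts_flat = [True, False, True, True, False, False]

# The buffer would hand these indices back alongside the sampled batch.
rng = random.Random(0)
batch_indices = rng.sample(range(len(obs_flat)), 4)

# The caller reconstructs the pairs, skipping entries without a successor.
pairs = []
for i in batch_indices:
    j = successor_index(i, n_steps, episode_starts_flat)
    if j is not None:
        pairs.append((obs_flat[i], obs_flat[j]))
```

This keeps the buffer itself unchanged, but every caller then has to know the flattened layout and the episode boundaries, which is the extra burden that makes this option less pleasant to use.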
Additional context
No response
Checklist
I have checked that there is no similar issue in the repo