Describe the Bug
I'm 90% sure that the current implementation of PER (prioritised experience replay) does not correctly prevent partially broken or overwritten trajectories from being sampled.
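To Reproduce
I'm not positive, but a test like the following gives a small indication. The import path and the `is_strictly_increasing` helper are assumptions added so the snippet runs on its own; the helper was not shown in the original report.

```python
import jax
import jax.numpy as jnp
import pytest

# Assumes this test lives in Flashbax's test suite, where the PER buffer module is importable like this.
from flashbax.buffers import prioritised_trajectory_buffer


def is_strictly_increasing(x: jnp.ndarray) -> bool:
    """Hypothetical helper: True if the values strictly increase along the time axis."""
    return bool(jnp.all(jnp.diff(x.ravel()) > 0))


@pytest.mark.parametrize("length", [9])
@pytest.mark.parametrize("add_batch_size", [16])
@pytest.mark.parametrize("sample_sequence_length", [3, 5])
@pytest.mark.parametrize("period", [1])
@pytest.mark.parametrize("max_length_time_axis", [20])
def test_prioritised_sample_doesnt_sample_prev_broken_trajectories(
    length: int,
    add_batch_size: int,
    sample_sequence_length: int,
    period: int,
    max_length_time_axis: int,
) -> None:
    """Test to ensure that `sample` avoids including rewards from broken trajectories."""
    fake_transition = {"reward": jnp.array([1])}
    # Give each batch row its own value range (row b starts at b * 1000).
    offset = jnp.arange(add_batch_size).reshape(add_batch_size, 1, 1) * 1000

    buffer = prioritised_trajectory_buffer.make_prioritised_trajectory_buffer(
        add_batch_size=add_batch_size,
        sample_batch_size=2048,
        sample_sequence_length=sample_sequence_length,
        period=period,
        max_length_time_axis=max_length_time_axis,
        min_length_time_axis=sample_sequence_length,
    )
    rng_key = jax.random.PRNGKey(0)
    state = buffer.init(fake_transition)

    # 10 adds of 9 steps = 90 steps into a 20-step time axis, so the buffer wraps several times.
    # Rewards keep increasing across adds, so every valid trajectory is strictly increasing.
    for i in range(10):
        fake_batch_sequence = {
            "reward": jnp.arange(length).reshape(1, length, 1).repeat(add_batch_size, axis=0)
            + offset
            + length * i
        }
        state = buffer.add(state, fake_batch_sequence)

    rng_key, rng_key1 = jax.random.split(rng_key)
    sample = buffer.sample(state, rng_key1)
    sampled_r = sample.experience["reward"]
    # A window that straddles the write head mixes the newest and the oldest data,
    # so it is not strictly increasing.
    for b in range(sampled_r.shape[0]):
        assert is_strictly_increasing(sampled_r[b])
```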
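Expected Behavior
Every sampled reward sequence should be strictly increasing (I think), because rewards only ever increase along each trajectory that gets written.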
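To make the expectation concrete, here is a pure-Python sketch of a single batch row from the test. It assumes, without having checked the buffer's source, that the time axis is a circular array of `max_length_time_axis` slots written one step at a time, and that sampled windows wrap around modulo that length. `simulate_row` and `broken_starts` are hypothetical helpers for illustration, not part of the buffer's API.

```python
from typing import List, Optional


def simulate_row(max_len: int, chunk_len: int, num_chunks: int) -> List[Optional[int]]:
    """Write `num_chunks` chunks of `chunk_len` steps into one circular time axis.

    The value stored at each step is the global step count, mirroring the test's
    `arange(length) + length * i` rewards, so data read in write order is strictly increasing.
    """
    row: List[Optional[int]] = [None] * max_len
    t = 0
    for _ in range(num_chunks):
        for _ in range(chunk_len):
            row[t % max_len] = t  # overwrite the oldest slot once the row is full
            t += 1
    return row


def broken_starts(row: List[Optional[int]], seq_len: int) -> List[int]:
    """Start slots whose wrapped window of `seq_len` steps is not strictly increasing.

    Assumes the row is full (no `None` slots left).
    """
    max_len = len(row)
    broken = []
    for start in range(max_len):
        window = [row[(start + k) % max_len] for k in range(seq_len)]
        if any(b <= a for a, b in zip(window, window[1:])):
            broken.append(start)
    return broken


row = simulate_row(max_len=20, chunk_len=9, num_chunks=10)  # 10 adds of 9 steps
print(broken_starts(row, seq_len=3))  # [8, 9]
print(broken_starts(row, seq_len=5))  # [6, 7, 8, 9]
```

After 90 steps the write head sits at slot 10, so slot 9 holds the newest reward (89) and slot 10 the oldest (70). Any window covering both slots is broken; if those start slots keep a nonzero priority, 2048 draws will almost surely include one.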
Outcome
All I know right now is that it's quite possible to sample the broken sequences.
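For reference, the broken start slots can be identified from the write index alone, without inspecting the data, which is roughly what a sampler would need in order to exclude them (for example by giving those items zero priority). This is a sketch under stated assumptions, not the buffer's actual logic: it assumes the row has already wrapped at least once, `period=1`, and a circular time axis whose windows wrap modulo `max_length_time_axis`. `invalid_starts` is a hypothetical helper.

```python
from typing import List


def invalid_starts(write_index: int, max_len: int, seq_len: int) -> List[int]:
    """Start slots whose length-`seq_len` window straddles the write head.

    Assumes a full (already wrapped) circular time axis with period 1, so slot
    `write_index - 1` holds the newest step and slot `write_index` the oldest.
    """
    invalid = []
    for start in range(max_len):
        # Circular distance from this start slot forward to the write head.
        dist_to_head = (write_index - start) % max_len
        # The window covers start .. start + seq_len - 1 (mod max_len); it crosses the
        # newest-to-oldest boundary iff the head lies strictly inside it.
        if 1 <= dist_to_head <= seq_len - 1:
            invalid.append(start)
    return invalid


# 90 steps written into 20 slots leaves the write head at slot 90 % 20 == 10.
print(invalid_starts(write_index=90 % 20, max_len=20, seq_len=3))  # [8, 9]
print(invalid_starts(write_index=90 % 20, max_len=20, seq_len=5))  # [6, 7, 8, 9]
```

In a JAX buffer the same check vectorizes as `(write_index - jnp.arange(max_len)) % max_len` compared against `1` and `seq_len - 1`, giving a boolean mask over start slots.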