
[BUG] PER does not correctly prevent sampling broken partial trajectories #49

Open
EdanToledo opened this issue Jan 28, 2025 · 0 comments
Labels
bug Something isn't working

@EdanToledo (Contributor)

Describe the Bug

I'm 90% sure that the current implementation of PER (prioritised experience replay) does not correctly prevent partially overwritten ("broken") trajectories from being sampled.
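To make the "broken trajectory" condition concrete, here is a minimal NumPy sketch of a single ring-buffered time axis, assuming period=1. It is not flashbax's implementation. The helpers `valid_start_mask`, `mask_priorities` and `simulate_writes` are hypothetical. Once the buffer has wrapped, the slot at the write pointer holds the oldest step. A window of `seq_len` steps is intact only if it does not run past the newest step into stale data. Zeroing the priority of every broken start position is one possible way to keep proportional sampling from ever picking one. This is an illustration only; the source does not confirm it as the intended fix.

```python
import numpy as np


def valid_start_mask(current_index: int, is_full: bool,
                     max_length: int, seq_len: int) -> np.ndarray:
    """True at ring-buffer positions where a window of `seq_len`
    consecutive timesteps (period=1 assumed) is fully intact."""
    positions = np.arange(max_length)
    if not is_full:
        # Before the first wrap, data occupies [0, current_index); a window
        # is intact only if it ends at or before the write pointer.
        return positions + seq_len <= current_index
    # Once full, the oldest item sits at the write pointer, so position p
    # holds the k-th oldest item with k = (p - current_index) mod max_length.
    age_rank = (positions - current_index) % max_length
    # A window is intact iff it does not run past the newest item.
    return age_rank + seq_len <= max_length


def mask_priorities(priorities: np.ndarray, valid: np.ndarray) -> np.ndarray:
    """Hypothetical fix: zero the priority of every broken start position
    so that proportional sampling can never select it."""
    return np.where(valid, priorities, 0.0)


def simulate_writes(total_steps: int, max_length: int):
    """Write timesteps 0..total_steps-1 into a ring buffer (-1 = empty).
    Returns (data, write pointer, whether the buffer has wrapped)."""
    data = np.full(max_length, -1)
    for t in range(total_steps):
        data[t % max_length] = t
    return data, total_steps % max_length, total_steps >= max_length


if __name__ == "__main__":
    # Same sizes as the test below: max_length_time_axis=20, three adds of 9.
    data, current_index, is_full = simulate_writes(27, 20)
    valid = valid_start_mask(current_index, is_full, 20, seq_len=5)
    print(np.flatnonzero(~valid))  # broken start positions: [3 4 5 6]
    probs = mask_priorities(np.ones(20), valid)
    probs /= probs.sum()  # sampling distribution with broken starts excluded
```

With 27 steps written into 20 slots, the write pointer is at 7. Windows starting at positions 3 to 6 read the newest steps 23 to 26 and then run into stale step 7, so they are broken.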

To Reproduce

I'm not positive, but something like this provides a small indication:

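import jax
import jax.numpy as jnp
import pytest

from flashbax.buffers import prioritised_trajectory_buffer


# Helper used by the snippet; defined here so the test runs standalone.
def is_strictly_increasing(x: jax.Array) -> bool:
    """Return True if the values strictly increase along the time axis."""
    return bool(jnp.all(jnp.diff(x, axis=0) > 0))


@pytest.mark.parametrize("length", [9])
@pytest.mark.parametrize("add_batch_size", [16])
@pytest.mark.parametrize("sample_sequence_length", [3, 5])
@pytest.mark.parametrize("period", [1])
@pytest.mark.parametrize("max_length_time_axis", [20])
def test_prioritised_sample_doesnt_sample_prev_broken_trajectories(
    length: int,
    add_batch_size: int,
    sample_sequence_length: int,
    period: int,
    max_length_time_axis: int,
) -> None:
    """Test to ensure that `sample` avoids including rewards from broken
    trajectories.
    """
    fake_transition = {"reward": jnp.array([1])}

    # Give each batch row its own range of values (row b starts at b * 1000).
    offset = jnp.arange(add_batch_size).reshape(add_batch_size, 1, 1) * 1000

    buffer = prioritised_trajectory_buffer.make_prioritised_trajectory_buffer(
        add_batch_size=add_batch_size,
        sample_batch_size=2048,
        sample_sequence_length=sample_sequence_length,
        period=period,
        max_length_time_axis=max_length_time_axis,
        min_length_time_axis=sample_sequence_length,
    )

    rng_key = jax.random.PRNGKey(0)
    state = buffer.init(fake_transition)

    # 10 adds of `length` steps into a time axis of 20 force repeated wrap-around.
    for i in range(10):
        # Reward at global step t of row b is b * 1000 + t.
        fake_batch_sequence = {
            "reward": jnp.arange(length)
            .reshape(1, length, 1)
            .repeat(add_batch_size, axis=0)
            + offset
            + length * i
        }
        state = buffer.add(state, fake_batch_sequence)

        rng_key, rng_key1 = jax.random.split(rng_key)

        sample = buffer.sample(state, rng_key1)

        # Shape: (sample_batch_size, sample_sequence_length, 1).
        sampled_r = sample.experience["reward"]

        # An intact trajectory stores consecutive steps, so any decrease means
        # the sampled sequence spans overwritten data.
        for b in range(sampled_r.shape[0]):
            assert is_strictly_increasing(sampled_r[b])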

Expected Behavior

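Within every sampled sequence the rewards should be strictly increasing, since an intact trajectory stores consecutive timesteps (at least, I think that is the right invariant).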
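The invariant can be illustrated without flashbax. The NumPy sketch below replays the test's reward layout for one batch row into a plain ring buffer of 20 slots, where the reward at global step t of row b is b * 1000 + t. The helpers `replay_row`, `window` and `is_consecutive` are hypothetical. They model a single time axis with period=1, not flashbax's actual storage. A window that wraps past the end of the array is still intact. A window that crosses the write pointer jumps from the newest steps back to stale ones, so the rewards drop.

```python
import numpy as np


def replay_row(num_adds: int, add_length: int, max_length: int,
               row: int, row_offset: int = 1000) -> np.ndarray:
    """Replay the test's writes for one batch row into a ring buffer.
    The reward stored at global step t of row b is b * row_offset + t;
    empty slots are -1."""
    data = np.full(max_length, -1)
    for i in range(num_adds):
        for t in range(add_length):
            step = i * add_length + t
            data[step % max_length] = row * row_offset + step
    return data


def window(data: np.ndarray, start: int, seq_len: int) -> np.ndarray:
    """Read seq_len slots starting at `start`, wrapping around the array."""
    return data[(start + np.arange(seq_len)) % len(data)]


def is_strictly_increasing(values) -> bool:
    """The invariant asserted by the test."""
    return bool(np.all(np.diff(np.ravel(values)) > 0))


def is_consecutive(values) -> bool:
    """Tighter check for period=1: neighbouring rewards differ by exactly 1."""
    return bool(np.all(np.diff(np.ravel(values)) == 1))


# Three adds of 9 steps into 20 slots (write pointer ends at position 7).
row0 = replay_row(num_adds=3, add_length=9, max_length=20, row=0)
intact = window(row0, 17, 5)  # wraps the array end but not the write pointer
broken = window(row0, 3, 5)   # newest steps followed by stale step 7
print(intact, is_strictly_increasing(intact), is_consecutive(intact))
print(broken, is_strictly_increasing(broken), is_consecutive(broken))
```

In this layout a broken window always steps from a newer reward to an older one, so the strictly-increasing check already flags it. The consecutive check is a tighter alternative when period=1.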

Outcome

All I know right now is that it is quite possible to sample broken sequences.

@EdanToledo added the bug (Something isn't working) label on Jan 28, 2025