[BUG] SyncDataCollector Crashes when init_random_frames=0 with a policy that is NOT random #2534

Open
3 tasks done
AlexandreBrown opened this issue Nov 4, 2024 · 2 comments · Fixed by #2645
Labels
bug Something isn't working


@AlexandreBrown

AlexandreBrown commented Nov 4, 2024

Describe the bug

Iterating over a SyncDataCollector that uses a standard Actor (not a random policy) with init_random_frames=0 crashes with a KeyError.

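from torchrl.collectors import SyncDataCollector
from torchrl.modules import Actor

# `agent` and `train_env` are defined elsewhere in the training script.
policy = Actor(
    agent,
    in_keys=["your_key"],
    out_keys=["action"],
    spec=train_env.action_spec,
)
train_data_collector = SyncDataCollector(
    create_env_fn=train_env,
    policy=policy,
    init_random_frames=0,  # no random warm-up frames
    ...
)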

Iterating over the collector triggers the crash:

for data in tqdm(train_data_collector, "Env Data Collection"):
    ...  # the KeyError is raised before the first batch is yielded

To Reproduce

  1. Create an actor that is not RandomPolicy
  2. Create a SyncDataCollector with the actor and set init_random_frames=0.
  3. Iterate over the data collector
  4. Observe the crash

Stack trace:

2024-11-04 12:04:33,606 [torchrl][INFO] check_env_specs succeeded!
2024-11-04 12:04:36.365 | INFO     | __main__:main:60 - Policy Device: cuda
2024-11-04 12:04:36.365 | INFO     | __main__:main:61 - Env Device: cpu
2024-11-04 12:04:36.365 | INFO     | __main__:main:62 - Storage Device: cpu
Env Data Collection:   0%|                                                                                                                                      | 0/1000000 [00:00<?, ?it/s]
Error executing job with overrides: ['env=dmc_reacher_hard', 'algo=sac_pixels']
Traceback (most recent call last):
  File "/home/user/Documents/SegDAC/./scripts/train_rl.py", line 119, in main
    trainer.train()
  File "/home/user/Documents/SegDAC/segdac_dev/src/segdac_dev/trainers/rl_trainer.py", line 40, in train
    for data in tqdm(self.train_data_collector, "Env Data Collection"):
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 247, in __iter__
    yield from self.iterator()
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1035, in iterator
    tensordict_out = self.rollout()
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/_utils.py", line 481, in unpack_rref_and_invoke_function
    return func(self, *args, **kwargs)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1166, in rollout
    env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/common.py", line 2862, in step_and_maybe_reset
    tensordict = self.step(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/common.py", line 1505, in step
    next_tensordict = self._step(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 783, in _step
    tensordict_in = self.transform.inv(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/nn/common.py", line 314, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 357, in inv
    out = self._inv_call(clone(tensordict))
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 1084, in _inv_call
    tensordict = t._inv_call(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 3656, in _inv_call
    return super()._inv_call(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 342, in _inv_call
    raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
KeyError: "'action' not found in tensordict TensorDict(\n    fields={\n        collector: TensorDict(\n            fields={\n                traj_ids: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},\n            batch_size=torch.Size([]),\n            device=cpu,\n            is_shared=False),\n        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n        pixels: Tensor(shape=torch.Size([3, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),\n        pixels_transformed: Tensor(shape=torch.Size([3, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),\n        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),\n        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n    batch_size=torch.Size([]),\n    device=cpu,\n    is_shared=False)"

Expected behavior

Iterating over the collector should work with init_random_frames=0 and a policy that is not random.

System info

  • Installation: pip install torchrl==0.6.0
  • Python version: 3.10
  • Versions of other relevant libraries (output of pip list):
Package                   Version    Editable project location
------------------------- ---------- -------------------------------------------
absl-py                   2.1.0
antlr4-python3-runtime    4.9.3
attrs                     24.2.0
av                        13.1.0
certifi                   2024.8.30
charset-normalizer        3.4.0
cloudpickle               3.1.0
comet-ml                  3.47.1
configobj                 5.0.9
dm_control                1.0.24
dm-env                    1.6
dm-tree                   0.1.8
dulwich                   0.22.4
etils                     1.10.0
everett                   3.1.0
filelock                  3.16.1
fsspec                    2024.10.0
glfw                      2.7.0
hydra-core                1.3.2
idna                      3.10
importlib_resources       6.4.5
Jinja2                    3.1.4
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
labmaze                   1.0.6
loguru                    0.7.2
lxml                      5.3.0
markdown-it-py            3.0.0
MarkupSafe                3.0.2
mdurl                     0.1.2
mpmath                    1.3.0
mujoco                    3.2.4
networkx                  3.4.2
numpy                     2.1.3
nvidia-cublas-cu12        12.4.5.8
nvidia-cuda-cupti-cu12    12.4.127
nvidia-cuda-nvrtc-cu12    12.4.127
nvidia-cuda-runtime-cu12  12.4.127
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.2.1.3
nvidia-curand-cu12        10.3.5.147
nvidia-cusolver-cu12      11.6.1.9
nvidia-cusparse-cu12      12.3.1.170
nvidia-nccl-cu12          2.21.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.4.127
omegaconf                 2.3.0
orjson                    3.10.11
packaging                 24.1
pillow                    11.0.0
pip                       24.3.1
protobuf                  5.28.3
psutil                    6.1.0
Pygments                  2.18.0
PyOpenGL                  3.1.7
pyparsing                 3.2.0
python-box                6.1.0
PyYAML                    6.0.2
referencing               0.35.1
requests                  2.32.3
requests-toolbelt         1.0.0
rich                      13.9.4
rpds-py                   0.20.1
scipy                     1.14.1
semantic-version          2.10.0
sentry-sdk                2.18.0
setuptools                75.3.0
simplejson                3.19.3
sympy                     1.13.1
tensordict                0.6.1
torch                     2.5.1
torchaudio                2.5.1
torchrl                   0.6.0
torchvision               0.20.1
tqdm                      4.66.5
triton                    3.1.0
typing_extensions         4.12.2
urllib3                   2.2.3
wheel                     0.44.0
wrapt                     1.16.0
wurlitzer                 3.1.1
zipp                      3.20.2

Reason and Possible fixes

It seems that self._policy_output_keys, which is set in SyncDataCollector::_make_final_rollout, is left empty ({}) when init_random_frames=0, and this causes unwanted behavior in SyncDataCollector::rollout.
More precisely, the problem shows up in these lines from SyncDataCollector::rollout:

policy_output = self.policy(policy_input)
if self._shuttle is not policy_output:
    # ad-hoc update shuttle
    self._shuttle.update(
        policy_output, keys_to_update=self._policy_output_keys
    )

In my case, policy_output was a tensordict containing the action key, but since self._policy_output_keys is {}, self._shuttle is never updated with the action key. The subsequent environment step then crashes with KeyError: "'action' not found in tensordict ...".
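The failure mode described above can be illustrated with a minimal sketch that uses plain dicts in place of tensordicts. It is not the torchrl or tensordict implementation: `update_selected` and `env_step` are hypothetical helpers, and the sketch assumes that an empty `keys_to_update` copies nothing, which is the behavior this report describes.

```python
# Plain-dict stand-in for the collector's shuttle; this is NOT the tensordict API.

def update_selected(target, source, keys_to_update=None):
    """Copy entries from source into target.

    Assumption: None means "copy every key", while an empty collection
    means "copy nothing", mirroring the behavior described in the issue.
    """
    if keys_to_update is None:
        keys_to_update = source.keys()
    for key in keys_to_update:
        target[key] = source[key]
    return target


def env_step(tensordict):
    """Hypothetical environment step that requires an 'action' entry."""
    if "action" not in tensordict:
        raise KeyError(f"'action' not found in tensordict {tensordict}")
    return {**tensordict, "step_count": tensordict["step_count"] + 1}


# The policy returns a new container that holds the action.
policy_output = {"pixels": "obs", "step_count": 0, "action": 0.5}

# Case from the issue: the key set is empty, so "action" never reaches the shuttle.
shuttle = {"pixels": "obs", "step_count": 0}
update_selected(shuttle, policy_output, keys_to_update=set())
try:
    env_step(shuttle)
    crashed = False
except KeyError:
    crashed = True

# With the policy's output keys recorded, the same step succeeds.
shuttle_ok = {"pixels": "obs", "step_count": 0}
update_selected(shuttle_ok, policy_output, keys_to_update={"action"})
stepped = env_step(shuttle_ok)
```

In this sketch the only difference between the failing and the working case is the content of the key set, which is consistent with the diagnosis that self._policy_output_keys being empty is what keeps the action out of self._shuttle.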

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
AlexandreBrown added the bug (Something isn't working) label Nov 4, 2024
AlexandreBrown changed the title from "[BUG] SyncDataCollector Crashes when policy_device=cuda and init_random_frames=0" to "[BUG] SyncDataCollector Crashes when init_random_frames=0" Nov 4, 2024
@vmoens
Contributor

vmoens commented Nov 4, 2024

On it!

@jannessm

+1

vmoens linked a pull request Dec 11, 2024 that will close this issue
AlexandreBrown changed the title from "[BUG] SyncDataCollector Crashes when init_random_frames=0" to "[BUG] SyncDataCollector Crashes when init_random_frames=0 with a policy that is NOT random" Dec 15, 2024

3 participants