[BUG] SyncDataCollector Crashes with Resources Leak During Data Collection #2644

Open · 3 tasks done
AlexandreBrown opened this issue Dec 11, 2024 · 6 comments
Labels: bug Something isn't working

@AlexandreBrown

Describe the bug

I've observed that my latest training runs crash after ~180k steps with the following message:

micromamba/envs/dmc_env/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 4 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
Killed
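
Side note: since the run ends with Killed (which usually comes from the kernel OOM killer), here is a minimal diagnostic sketch of how host memory could be tracked during collection to confirm whether RSS keeps growing. This is not part of my setup; names like train_data_collector, replay_buffer and the 10k logging interval are placeholders mirroring the repro steps below, and psutil is already in the environment:

import os
import psutil

process = psutil.Process(os.getpid())
env_step = 0

for data in train_data_collector:  # the SyncDataCollector from step 3
    replay_buffer.extend(data)     # the TensorDictReplayBuffer from step 2
    env_step += 1                  # frames_per_batch = 1 in my setup
    if env_step % 10_000 == 0:
        # Resident memory of the training process; a steady climb toward the
        # 48 GB RAM limit would point to the OOM killer rather than a CUDA issue.
        print(f"step={env_step} rss={process.memory_info().rss / 1e9:.2f} GB")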

To Reproduce

  1. Create the DMControlEnv as follows:
env = TransformedEnv(
    DMControlEnv(
        env_name="cartpole",
        task_name="swingup",
        from_pixels=True,
        pixels_only=True,
        device="cuda"
    )
)

env.append_transform(DoubleToFloat())

env.append_transform(
    FrameSkipTransform(frame_skip=2)
)
env.append_transform(InitTracker())

env.append_transform(
    PermuteTransform(
        dims=(-1, -2, -3), in_keys=["pixels"], out_keys=["pixels"]
    ),  # H W C -> C W H
)
env.append_transform(
    Resize(
        w=cfg["env"]["pixels"]["width"],
        h=cfg["env"]["pixels"]["height"],
        in_keys=["pixels"],
        out_keys=["pixels"],
    ),  # C W H -> C W' H'
)

env.append_transform(
    PermuteTransform(
        dims=(-3, -1, -2), in_keys=["pixels"], out_keys=["pixels"]
    ),  # C W' H' -> C H' W'
)

env.append_transform(
    UnsqueezeTransform(
        dim=-4,
        in_keys=["pixels"],
        out_keys=["pixels"],
    )  # C H' W' -> 1 C H' W'
)

env.append_transform(
    CatFrames(
        N=int(frame_stack),
        dim=-4,
        in_keys=["pixels"],
        out_keys=["pixels"],
    )  # 1 C H' W' -> N C H' W'
)

# Other transforms omitted for brevity
  2. Create the replay buffer as follows:
import torch
from omegaconf import DictConfig
from torchrl.data import ReplayBuffer
from torchrl.data import TensorDictReplayBuffer
from torchrl.data import LazyMemmapStorage
from torchrl.data import LazyTensorStorage
from torchrl.envs.transforms import Compose
from torchrl.envs.transforms import ExcludeTransform
from segdac.action_scaling.env_action_scaler import TanhEnvActionScaler
from segdac_dev.envs.transforms.unscale_image import UnscaleImage
from segdac_dev.envs.transforms.unscale_action import UnscaleAction
from hydra.utils import instantiate


def get_replay_buffer_data_saving_transforms(cfg: DictConfig) -> list:
    """
    These are transforms executed when saving data to the replay buffer.
    We want to exclude pixels_transformed because it is float32 (expensive to store); we store the uint8 RGB image instead.
    """
    transforms = [
        ExcludeTransform(
            "pixels_transformed", ("next", "pixels_transformed"), inverse=True
        ),
    ]

    for save_transform_config in (
        cfg.get("algo", {}).get("replay_buffer", {}).get("save_transforms", [])
    ):
        save_transform = instantiate(save_transform_config)
        assert save_transform.inverse is True

        if isinstance(save_transform, ExcludeTransform):
            arg_key = "_args_"
            next_args = []
            for key in save_transform_config.get(arg_key, []):
                next_args.append(("next", key))
            save_transform.excluded_keys = save_transform.excluded_keys + next_args
        elif isinstance(save_transform, UnscaleImage):
            arg_key = "in_keys_inv"
            next_in_keys_inv_args = []
            for key in save_transform_config.get(arg_key, []):
                next_in_keys_inv_args.append(("next", key))
            save_transform.in_keys_inv = (
                save_transform.in_keys_inv + next_in_keys_inv_args
            )
            arg_key = "out_keys_inv"
            next_out_keys_inv_args = []
            for key in save_transform_config.get(arg_key, []):
                next_out_keys_inv_args.append(("next", key))
            save_transform.out_keys_inv = (
                save_transform.out_keys_inv + next_out_keys_inv_args
            )

        transforms.append(save_transform)

    return transforms


def get_replay_buffer_sample_transforms(
    cfg: DictConfig, env_action_scaler: TanhEnvActionScaler
) -> list:
    """
    These are transforms executed when sampling data from the replay buffer.
    """
    transforms = []
    for sample_transform_config in (
        cfg.get("algo", {}).get("replay_buffer", {}).get("sample_transforms", [])
    ):
        sample_transform = instantiate(sample_transform_config)
        sample_transform.in_keys = sample_transform.in_keys + [
            ("next", key) for key in sample_transform.in_keys
        ]
        sample_transform.out_keys = sample_transform.out_keys + [
            ("next", key) for key in sample_transform.out_keys
        ]
        transforms.append(sample_transform)

    transforms.append(UnscaleAction(env_action_scaler))

    return transforms


def create_replay_buffer(cfg: DictConfig, env_action_scaler) -> ReplayBuffer:
    storage_device = torch.device(cfg["storage_device"]) # cpu in my test
    capacity = cfg["algo"]["replay_buffer"]["capacity"] # 1M in my test

    transforms = []
    transforms.extend(get_replay_buffer_data_saving_transforms(cfg))
    transforms.extend(get_replay_buffer_sample_transforms(cfg, env_action_scaler))
    transform = Compose(*transforms)

    storage_kwargs = {}
    storage_kwargs["max_size"] = capacity
    storage_kwargs["device"] = storage_device
    storage_dim = 1
    if cfg["env"]["num_workers"] > 1: # In my test num_workers = 1
        storage_dim += 1
    storage_kwargs["ndim"] = storage_dim

    if "cpu" in storage_device.type: # cpu was used in my test
        # LazyMemmapStorage is only supported on CPU
        replay_buffer = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(**storage_kwargs),
            transform=transform,
            batch_size=int(cfg["training"]["batch_size"]), # 128 in my test
        )
    else:
        replay_buffer = TensorDictReplayBuffer(
            storage=LazyTensorStorage(**storage_kwargs),
            transform=transform,
            batch_size=int(cfg["training"]["batch_size"]),
        )

    return replay_buffer
  3. Create the SyncDataCollector:
SyncDataCollector(
    create_env_fn=env,
    policy=policy,
    total_frames=data_collector_cfg["total_frames"],  # 1M in my test
    max_frames_per_traj=max_frames_per_traj,  # 1000 in my test
    frames_per_batch=frames_per_batch,  # 1 in my test
    env_device=cfg["env"]["device"],  # cuda in my test
    storing_device=cfg["storage_device"],  # cpu in my test
    policy_device=cfg["policy_device"],  # cuda in my test
    exploration_type=exploration_type,  # RANDOM
    init_random_frames=data_collector_cfg.get("init_random_frames", 0),  # 1000 in my test
    postproc=None,
)
  4. Yield from the data collector (the crash occurs at ~180k steps for me, 2 trainings in a row):
from tqdm import tqdm

num_iters = 1_000_000
env_step = 0
for data in tqdm(
    self.train_data_collector, "Env Data Collection", total=num_iters
):
    env_step += self.train_frames_per_batch

    self.replay_buffer.extend(data)

Expected behavior

No crash


System info

  • CPU : 6
  • GPU : 1xA100
  • Disk : 100GB
  • RAM : 48GB
  • Headless : yes (cluster)
  • Python : 3.10.16
  • TorchRL : 0.6.0
  • Torch: 2.5.1
  • pip list :
Package                   Version                                                       Editable project location
------------------------- ------------------------------------------------------------- -------------------------------------
absl-py                   2.1.0
annotated-types           0.7.0
antlr4-python3-runtime    4.9.3
asttokens                 3.0.0
attrs                     24.2.0
av                        13.1.0
certifi                   2024.8.30
charset-normalizer        3.4.0
click                     8.1.7
clip                      1.0
cloudpickle               3.1.0
coloredlogs               15.0.1
comet-ml                  3.47.1
comm                      0.2.2
configobj                 5.0.9
contourpy                 1.3.1
cycler                    0.12.1
Cython                    3.0.11
debugpy                   1.8.9
decorator                 5.1.1
diffusers                 0.31.0
dm_control                1.0.25
dm-env                    1.6
dm-tree                   0.1.8
docker-pycreds            0.4.0
dulwich                   0.22.6
efficientvit              0.0.0
einops                    0.8.0
etils                     1.11.0
everett                   3.1.0
exceptiongroup            1.2.2
executing                 2.1.0
filelock                  3.16.1
flatbuffers               24.3.25
fonttools                 4.55.2
fsspec                    2024.10.0
ftfy                      6.3.1
gitdb                     4.0.11
GitPython                 3.1.43
glfw                      2.8.0
huggingface-hub           0.26.2
humanfriendly             10.0
hydra-core                1.3.2
idna                      3.10
igraph                    0.11.8
imageio                   2.36.1
importlib_metadata        8.5.0
importlib_resources       6.4.5
ipdb                      0.13.13
ipykernel                 6.29.5
ipython                   8.30.0
ipywidgets                8.1.5
jedi                      0.19.2
Jinja2                    3.1.4
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
jupyter_client            8.6.3
jupyter_core              5.7.2
jupyterlab_widgets        3.0.13
kiwisolver                1.4.7
labmaze                   1.0.6
lazy_loader               0.4
lightning-utilities       0.11.9
loguru                    0.7.2
lvis                      0.5.3
lxml                      5.3.0
markdown-it-py            3.0.0
MarkupSafe                3.0.2
matplotlib                3.9.3
matplotlib-inline         0.1.7
mdurl                     0.1.2
mpmath                    1.3.0
mujoco                    3.2.6
nest-asyncio              1.6.0
networkx                  3.4.2
numpy                     1.26.4
nvidia-cublas-cu12        12.4.5.8
nvidia-cuda-cupti-cu12    12.4.127
nvidia-cuda-nvrtc-cu12    12.4.127
nvidia-cuda-runtime-cu12  12.4.127
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.2.1.3
nvidia-curand-cu12        10.3.5.147
nvidia-cusolver-cu12      11.6.1.9
nvidia-cusparse-cu12      12.3.1.170
nvidia-nccl-cu12          2.21.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.4.127
omegaconf                 2.3.0
onnx                      1.17.0
onnxruntime               1.20.1
onnxsim                   0.4.36
opencv-python             4.10.0.84
opencv-python-headless    4.10.0.84
orjson                    3.10.12
packaging                 24.2
pandas                    2.2.3
parso                     0.8.4
pexpect                   4.9.0
pillow                    11.0.0
pip                       24.3.1
platformdirs              4.3.6
prompt_toolkit            3.0.48
protobuf                  5.29.1
psutil                    6.1.0
ptyprocess                0.7.0
pure_eval                 0.2.3
py-cpuinfo                9.0.0
pycocotools               2.0.8
pydantic                  2.10.3
pydantic_core             2.27.1
Pygments                  2.18.0
PyOpenGL                  3.1.7
PyOpenGL-accelerate       3.1.7
pyparsing                 3.2.0
python-box                6.1.0
python-dateutil           2.9.0.post0
pytz                      2024.2
PyYAML                    6.0.2
pyzmq                     26.2.0
referencing               0.35.1
regex                     2024.11.6
requests                  2.32.3
requests-toolbelt         1.0.0
rich                      13.9.4
rpds-py                   0.22.3
ruamel.yaml               0.18.6
ruamel.yaml.clib          0.2.12
safetensors               0.4.5
scikit-image              0.24.0
scipy                     1.14.1
seaborn                   0.13.2
XXX                    0.0.1                                                       
XXX                0.0.1                                                   
segment_anything          1.0
semantic-version          2.10.0
sentry-sdk                2.19.2
setproctitle              1.3.4
setuptools                75.6.0
simplejson                3.19.3
six                       1.17.0
smmap                     5.0.1
stack-data                0.6.3
sympy                     1.13.1
tensordict                0.6.2
texttable                 1.7.0
tifffile                  2024.9.20
timm                      1.0.12
TinyNeuralNetwork         0.1.0.20241202154922+f79b0ccf02a92247c9cae4ac403c33917f8f6f6f
tokenizers                0.21.0
tomli                     2.2.1
torch                     2.5.1
torch-fidelity            0.3.0
torchaudio                2.5.1
torchmetrics              1.6.0
torchprofile              0.0.4
torchrl                   0.6.0
torchvision               0.20.1
tornado                   6.4.2
tqdm                      4.66.5
traitlets                 5.14.3
transformers              4.47.0
triton                    3.1.0
typing_extensions         4.12.2
tzdata                    2024.2
ultralytics               8.3.48
ultralytics-thop          2.0.13
urllib3                   2.2.3
wandb                     0.19.0
wcwidth                   0.2.13
wheel                     0.45.1
widgetsnbextension        4.0.13
wrapt                     1.17.0
wurlitzer                 3.1.1
zipp                      3.21.0
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

Output:

0.6.0 1.26.4 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0] linux

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
AlexandreBrown added the bug label on Dec 11, 2024
@AlexandreBrown (Author)

Maybe there is a common root cause between #2614 and this issue.

@AlexandreBrown (Author)

Maybe it's related to my disk storage being too small? I'm storing stacked frames (4, 3, 84, 84) into my replay buffer, which uses LazyMemmapStorage.
Could be related to #914
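
A rough back-of-the-envelope estimate, assuming the memmap ends up storing both pixels and ("next", "pixels") as uint8 at the full 1M capacity:

# Per-transition pixel payload for N=4 stacked (3, 84, 84) uint8 frames
bytes_per_obs = 4 * 3 * 84 * 84            # 84,672 bytes (~83 KB)
bytes_per_transition = 2 * bytes_per_obs   # "pixels" + ("next", "pixels")
capacity = 1_000_000

total_gb = bytes_per_transition * capacity / 1e9
print(f"~{total_gb:.0f} GB")               # ~169 GB, above the 100 GB disk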

@yu-fz (Contributor) commented Dec 17, 2024

I have run into a similar problem before. When I was using TorchRL with IsaacLab, I would have training runs die midway through when using SyncDataCollector. I made a wrapper for SyncDataCollector that overrode the iterator() function to remove the CUDA memory management code at the beginning, and that seemed to fix the problem.

https://github.com/isaac-sim/IsaacLab/pull/1178/files#diff-82f19e3b1196887a446d1932e2626119f009999373febb153e0fe60e422da9aa


from typing import Iterator

from tensordict import TensorDictBase
from torchrl.collectors import SyncDataCollector


class SyncDataCollectorWrapper(SyncDataCollector):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def iterator(self) -> Iterator[TensorDictBase]:
        """Iterates through the DataCollector.

        Yields: TensorDictBase objects containing (chunks of) trajectories
        """
        # The portion of the code handling CUDA streams has been removed from this
        # overridden method, since it caused CUDA memory allocation issues with
        # IsaacSim during env stepping.
        total_frames = self.total_frames
        # ... (remainder of the upstream iterator() body, unchanged)
@vmoens (Contributor) commented Dec 18, 2024

Is this what you were running into @AlexandreBrown?
@yu-fz Do you think I should patch SyncDataCollector to allow users to turn off the stream? Another option would be to deactivate it whenever self.return_same_td=False:

if self.return_same_td:
    # This is used with multiprocessed collectors to use the buffers
    # stored in the tensordict.
    if events:
        for event in events:
            event.record()
            event.synchronize()
    yield tensordict_out

IIRC this is the only scenario where this is really needed, and you won't be using multiprocessed collectors with Isaac (presumably).

@fyu-bdai

I think either would work! There shouldn't be a case where one would use multiprocessed collectors with IsaacSim.

@AlexandreBrown (Author)

I haven't tried the fix but I'm not opposed to having the option.
