Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add standalone scripts to enable torchrl workflow #1179

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions source/extensions/omni.isaac.lab/docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
---------

0.24.20 (2024-10-07)
~~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added ``play.py`` and ``train.py`` scripts to support new torchrl workflow.


0.24.19 (2024-10-05)
~~~~~~~~~~~~~~~~~~~~

Expand Down
150 changes: 150 additions & 0 deletions source/standalone/workflows/torchrl/cli_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from __future__ import annotations

import argparse
import random
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from omni.isaac.lab_tasks.utils.wrappers.torchrl import OnPolicyPPORunnerCfg


def add_torchrl_args(parser: argparse.ArgumentParser):
    """Add TorchRL arguments to the parser.

    Adds the following fields to argparse:

    - "--experiment_name" : Name of the experiment folder where logs will be stored (default: None).
    - "--run_name" : Run name suffix to the log directory (default: None).
    - "--resume" : Whether to resume from a checkpoint (default: None).
    - "--load_run" : Name of the run folder to resume from (default: None).
    - "--checkpoint" : Checkpoint file to resume from (default: None).
    - "--logger" : Logger module to use (default: None).
    - "--log_project_name" : Name of the logging project when using wandb or neptune (default: None).

    Args:
        parser: The parser to add the arguments to.
    """

    def _str_to_bool(value: str) -> bool:
        # argparse's ``type=bool`` treats ANY non-empty string (including
        # "False") as True, so parse the text explicitly instead.
        if value.lower() in ("true", "1", "yes"):
            return True
        if value.lower() in ("false", "0", "no"):
            return False
        raise argparse.ArgumentTypeError(f"Invalid boolean value: '{value}'")

    # create a new argument group
    arg_group = parser.add_argument_group("torchrl", description="Arguments for TorchRL agent.")
    # -- experiment arguments
    arg_group.add_argument(
        "--experiment_name",
        type=str,
        default=None,
        help="Name of the experiment folder where logs will be stored.",
    )
    arg_group.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="Run name suffix to the log directory.",
    )
    # -- load arguments
    arg_group.add_argument(
        "--resume",
        type=_str_to_bool,
        default=None,
        help="Whether to resume from a checkpoint.",
    )
    arg_group.add_argument(
        "--load_run",
        type=str,
        default=None,
        help="Name of the run folder to resume from.",
    )
    arg_group.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint file to resume from.",
    )
    # -- logger arguments
    arg_group.add_argument(
        "--logger",
        type=str,
        default=None,
        choices={"wandb", "tensorboard", "neptune"},
        help="Logger module to use.",
    )
    arg_group.add_argument(
        "--log_project_name",
        type=str,
        default=None,
        help="Name of the logging project when using wandb or neptune.",
    )


def parse_torchrl_cfg(task_name: str, args_cli: argparse.Namespace) -> OnPolicyPPORunnerCfg:
    """Parse configuration for the TorchRL agent based on inputs.

    The default agent configuration is loaded from the gym registry entry point
    ``"torchrl_cfg_entry_point"`` and then selectively overridden with matching
    CLI arguments (only those that were explicitly provided).

    Args:
        task_name: The name of the environment.
        args_cli: The command line arguments.

    Returns:
        The parsed configuration for the TorchRL agent based on inputs.
    """
    from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry

    # load the default configuration from the task registry
    torchrl_cfg: OnPolicyPPORunnerCfg = load_cfg_from_registry(task_name, "torchrl_cfg_entry_point")

    # select the simulation device from the CLI flags
    # NOTE(review): assumes ``args_cli`` always provides ``cpu`` and ``physics_gpu`` — confirm callers
    torchrl_cfg.device = "cpu" if args_cli.cpu else f"cuda:{args_cli.physics_gpu}"

    # override the default configuration with CLI arguments (only when provided)
    if args_cli.seed is not None:
        torchrl_cfg.seed = args_cli.seed
    if args_cli.resume is not None:
        torchrl_cfg.resume = args_cli.resume
    if args_cli.load_run is not None:
        torchrl_cfg.load_run = args_cli.load_run
    if args_cli.checkpoint is not None:
        torchrl_cfg.load_checkpoint = args_cli.checkpoint
    if args_cli.run_name is not None:
        torchrl_cfg.run_name = args_cli.run_name
    if args_cli.logger is not None:
        torchrl_cfg.logger = args_cli.logger
    # set the project name when logging to wandb
    if torchrl_cfg.logger == "wandb" and args_cli.log_project_name:
        torchrl_cfg.wandb_project = args_cli.log_project_name

    return torchrl_cfg


def update_torchrl_cfg(agent_cfg: OnPolicyPPORunnerCfg, args_cli: argparse.Namespace):
    """Update configuration for the TorchRL agent based on inputs.

    Mutates ``agent_cfg`` in place, overriding fields with any CLI arguments
    that were explicitly provided. A seed of ``-1`` is replaced with a randomly
    sampled seed in ``[0, 10000]``.

    Args:
        agent_cfg: The configuration for the TorchRL agent.
        args_cli: The command line arguments.

    Returns:
        The updated configuration for the TorchRL agent based on inputs.
    """
    # override the default configuration with CLI arguments (only when provided)
    if hasattr(args_cli, "seed") and args_cli.seed is not None:
        # randomly sample a seed if seed = -1
        if args_cli.seed == -1:
            args_cli.seed = random.randint(0, 10000)
        agent_cfg.seed = args_cli.seed
    if args_cli.resume is not None:
        agent_cfg.resume = args_cli.resume
    if args_cli.load_run is not None:
        agent_cfg.load_run = args_cli.load_run
    if args_cli.checkpoint is not None:
        agent_cfg.load_checkpoint = args_cli.checkpoint
    if args_cli.run_name is not None:
        agent_cfg.run_name = args_cli.run_name
    if args_cli.logger is not None:
        agent_cfg.logger = args_cli.logger
    # set the project name when logging to wandb (same check as parse_torchrl_cfg)
    if agent_cfg.logger == "wandb" and args_cli.log_project_name:
        agent_cfg.wandb_project = args_cli.log_project_name

    return agent_cfg
132 changes: 132 additions & 0 deletions source/standalone/workflows/torchrl/play.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Script to play a checkpoint for an RL agent from TorchRL."""

"""Launch Isaac Sim Simulator first."""

import argparse

from omni.isaac.lab.app import AppLauncher

# local imports
import cli_args # isort: skip

# add argparse arguments
parser = argparse.ArgumentParser(description="Play an RL agent with TorchRL.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")

# append torchrl cli arguments
cli_args.add_torchrl_args(parser)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True

# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app

"""Rest everything follows."""

import gymnasium as gym
import os
import torch

from torchrl.envs.utils import ExplorationType, set_exploration_type

from omni.isaac.lab.utils.dict import print_dict

import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.torchrl import (
OnPolicyPPORunner,
OnPolicyPPORunnerCfg,
TorchRLEnvWrapper,
export_policy_as_onnx,
)


def main():
    """Play a trained TorchRL agent: load a checkpoint, export it to ONNX, and run inference."""
    # parse environment and agent configuration
    env_cfg = parse_env_cfg(
        args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
    )
    agent_cfg: OnPolicyPPORunnerCfg = cli_args.parse_torchrl_cfg(args_cli.task, args_cli)

    # specify directory for logging experiments
    log_root_path = os.path.join("logs", "torchrl", agent_cfg.experiment_name)
    log_root_path = os.path.abspath(log_root_path)
    print(f"[INFO] Loading experiment from directory: {log_root_path}")
    # resolve the checkpoint to load (honors --load_run / --checkpoint overrides)
    resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
    log_dir = os.path.dirname(resume_path)

    # create isaac environment
    env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
    # wrap for video recording
    if args_cli.video:
        video_kwargs = {
            "video_folder": os.path.join(log_dir, "videos", "play"),
            # record a single clip starting at the first step
            "step_trigger": lambda step: step == 0,
            "video_length": args_cli.video_length,
            "disable_logger": True,
        }
        print("[INFO] Recording videos during play.")
        print_dict(video_kwargs, nesting=4)
        env = gym.wrappers.RecordVideo(env, **video_kwargs)

    # wrap around environment for torchrl
    env = TorchRLEnvWrapper(env)

    print(f"environment observation spec: {env.observation_spec}")
    print(f"environment action spec: {env.action_spec}")

    print(f"[INFO]: Loading model checkpoint from: {resume_path}")
    # load previously trained model (no logging directory needed during play)
    ppo_runner = OnPolicyPPORunner(env, agent_cfg, log_dir=None, device=agent_cfg.device)
    ppo_runner.load(resume_path, eval_mode=True)

    # obtain the trained policy for inference
    policy = ppo_runner.actor_module

    # export policy to onnx, next to the loaded checkpoint
    export_model_dir = os.path.join(log_dir, "exported")
    export_policy_as_onnx(ppo_runner.loss_module_cfg, normalizer=None, path=export_model_dir, filename="policy.onnx")

    # reset environment
    td = env.reset()
    timestep = 0
    # simulate environment
    while simulation_app.is_running():
        # run everything in inference mode with deterministic (mean) actions
        with set_exploration_type(ExplorationType.MEAN), torch.inference_mode():
            # agent stepping
            td = policy(td)
            td = env.step(td)
        if args_cli.video:
            timestep += 1
            # exit the play loop after recording one video
            if timestep == args_cli.video_length:
                break

    # close the simulator
    env.close()


if __name__ == "__main__":
# run the main function
main()
# close sim app
simulation_app.close()
Loading