Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into release/0.3.0
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 31, 2024
2 parents b9cf712 + 69453a6 commit 52a12a8
Show file tree
Hide file tree
Showing 120 changed files with 1,136 additions and 224 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
env.backend=gymnasium \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_offline.py \
optim.gradient_steps=55 \
Expand Down
9 changes: 8 additions & 1 deletion examples/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,14 @@ def main(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend:
exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
logger = get_logger(
cfg.logger.backend, logger_name="a2c", experiment_name=exp_name
cfg.logger.backend,
logger_name="a2c",
experiment_name=exp_name,
wandb_kwargs={
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

# Create test environment
Expand Down
9 changes: 8 additions & 1 deletion examples/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ def main(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend:
exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
logger = get_logger(
cfg.logger.backend, logger_name="a2c", experiment_name=exp_name
cfg.logger.backend,
logger_name="a2c",
experiment_name=exp_name,
wandb_kwargs={
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

# Create test environment
Expand Down
2 changes: 2 additions & 0 deletions examples/a2c/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ collector:
# logger
logger:
backend: wandb
project_name: torchrl_example_a2c
group_name: null
exp_name: Atari_Schulman17
test_interval: 40_000_000
num_test_episodes: 3
Expand Down
4 changes: 3 additions & 1 deletion examples/a2c/config_mujoco.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# task and env
env:
env_name: HalfCheetah-v3
env_name: HalfCheetah-v4

# collector
collector:
Expand All @@ -10,6 +10,8 @@ collector:
# logger
logger:
backend: wandb
project_name: torchrl_example_a2c
group_name: null
exp_name: Mujoco_Schulman17
test_interval: 1_000_000
num_test_episodes: 5
Expand Down
7 changes: 7 additions & 0 deletions examples/bandits/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Bandits example

## Note:
This example is not included in the benchmarked results of the current release (v0.3). We intend to include it in the
benchmarks of future releases, to ensure that it runs successfully with the release code and that its
results are consistent. For now, be aware that this check has not yet been performed for this
example.
9 changes: 7 additions & 2 deletions examples/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,19 @@
@hydra.main(config_path=".", config_name="offline_config", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
# Create logger
exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name)
exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="cql_logging",
experiment_name=exp_name,
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
wandb_kwargs={
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)
# Set seeds
torch.manual_seed(cfg.env.seed)
Expand Down
9 changes: 7 additions & 2 deletions examples/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,19 @@
@hydra.main(version_base="1.1", config_path=".", config_name="online_config")
def main(cfg: "DictConfig"): # noqa: F821
# Create logger
exp_name = generate_exp_name("CQL-online", cfg.env.exp_name)
exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="cql_logging",
experiment_name=exp_name,
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
wandb_kwargs={
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

# Set seeds
Expand Down
4 changes: 3 additions & 1 deletion examples/cql/discrete_cql_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ env:
name: CartPole-v1
task: ""
backend: gym
exp_name: cql_cartpole_gym
n_samples_stats: 1000
max_episode_steps: 200
seed: 0
Expand All @@ -24,6 +23,9 @@ collector:
# Logger
logger:
backend: wandb
project_name: torchrl_example_cql
group_name: null
exp_name: cql_cartpole_gym
log_interval: 5000 # record interval in frames
eval_steps: 200
mode: online
Expand Down
8 changes: 6 additions & 2 deletions examples/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,18 @@ def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.optim.device)

# Create logger
exp_name = generate_exp_name("DiscreteCQL", cfg.env.exp_name)
exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="discretecql_logging",
experiment_name=exp_name,
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
wandb_kwargs={
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.logger.project_name,
},
)

# Set seeds
Expand Down
6 changes: 4 additions & 2 deletions examples/cql/offline_config.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# env and task
env:
name: Hopper-v2
name: Hopper-v4
task: ""
library: gym
exp_name: cql_${replay_buffer.dataset}
n_samples_stats: 1000
seed: 0
backend: gym # D4RL uses gym so we make sure gymnasium is hidden

# logger
logger:
backend: wandb
project_name: torchrl_example_cql
group_name: null
exp_name: cql_${replay_buffer.dataset}
eval_iter: 5000
eval_steps: 1000
mode: online
Expand Down
4 changes: 3 additions & 1 deletion examples/cql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
env:
name: Pendulum-v1
task: ""
exp_name: cql_${env.name}
n_samples_stats: 1000
seed: 0
train_num_envs: 1
Expand All @@ -23,6 +22,9 @@ collector:
# logger
logger:
backend: wandb
project_name: torchrl_example_cql
group_name: null
exp_name: cql_${env.name}
log_interval: 5000 # record interval in frames
eval_steps: 1000
mode: online
Expand Down
8 changes: 5 additions & 3 deletions examples/ddpg/config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# environment and task
env:
name: HalfCheetah-v3
name: HalfCheetah-v4
task: ""
exp_name: ${env.name}_DDPG
library: gymnasium
max_episode_steps: 1000
seed: 42
Expand All @@ -22,7 +21,7 @@ collector:
replay_buffer:
size: 1000000
prb: 0 # use prioritized experience replay
scratch_dir: ${env.exp_name}_${env.seed}
scratch_dir: ${logger.exp_name}_${env.seed}

# optimization
optim:
Expand All @@ -44,5 +43,8 @@ network:
# logging
logger:
backend: wandb
project_name: torchrl_example_ddpg
group_name: null
exp_name: ${env.name}_DDPG
mode: online
eval_iter: 25000
9 changes: 7 additions & 2 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,19 @@ def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)

# Create logger
exp_name = generate_exp_name("DDPG", cfg.env.exp_name)
exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="ddpg_logging",
experiment_name=exp_name,
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
wandb_kwargs={
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

# Set seeds
Expand Down
4 changes: 3 additions & 1 deletion examples/decision_transformer/dt_config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# environment and task
env:
name: HalfCheetah-v3
name: HalfCheetah-v4
task: ""
library: gym
stacked_frames: 20
Expand All @@ -20,7 +20,9 @@ env:
# logger
logger:
backend: wandb
project_name: torchrl_example_dt
model_name: DT
group_name: null
exp_name: DT-HalfCheetah-medium-v2
pretrain_log_interval: 500 # record interval in frames
fintune_log_interval: 1
Expand Down
4 changes: 3 additions & 1 deletion examples/decision_transformer/odt_config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# environment and task
env:
name: HalfCheetah-v3
name: HalfCheetah-v4
task: ""
library: gym
stacked_frames: 20
Expand All @@ -20,6 +20,8 @@ env:
# logger
logger:
backend: wandb
project_name: torchrl_example_odt
group_name: null
exp_name: oDT-HalfCheetah-medium-v2
model_name: oDT
pretrain_log_interval: 500 # record interval in frames
Expand Down
9 changes: 5 additions & 4 deletions examples/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,17 +493,18 @@ def make_dt_optimizer(optim_cfg, loss_module):


def make_logger(cfg):
from omegaconf import OmegaConf

if not cfg.logger.backend:
return None
exp_name = generate_exp_name(cfg.logger.model_name, cfg.logger.exp_name)
cfg.logger.exp_name = exp_name
logger = get_logger(
cfg.logger.backend,
logger_name=cfg.logger.model_name,
experiment_name=exp_name,
wandb_kwargs={"config": OmegaConf.to_container(cfg)},
wandb_kwargs={
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)
return logger

Expand Down
6 changes: 4 additions & 2 deletions examples/discrete_sac/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
env:
name: CartPole-v1
task: ""
exp_name: ${env.name}_DiscreteSAC
library: gym
seed: 42
max_episode_steps: 500
Expand All @@ -23,7 +22,7 @@ collector:
replay_buffer:
prb: 0 # use prioritized experience replay
size: 1000000
scratch_dir: ${env.exp_name}_${env.seed}
scratch_dir: ${logger.exp_name}_${env.seed}

# optim
optim:
Expand All @@ -48,5 +47,8 @@ network:
# logging
logger:
backend: wandb
project_name: torchrl_example_discrete_sac
group_name: null
exp_name: ${env.name}_DiscreteSAC
mode: online
eval_iter: 5000
9 changes: 7 additions & 2 deletions examples/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,19 @@ def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)

# Create logger
exp_name = generate_exp_name("DiscreteSAC", cfg.env.exp_name)
exp_name = generate_exp_name("DiscreteSAC", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="DiscreteSAC_logging",
experiment_name=exp_name,
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
wandb_kwargs={
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

# Set seeds
Expand Down
4 changes: 3 additions & 1 deletion examples/dqn/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ buffer:

# logger
logger:
backend: null
backend: wandb
project_name: torchrl_example_dqn
group_name: null
exp_name: DQN
test_interval: 1_000_000
num_test_episodes: 3
Expand Down
4 changes: 3 additions & 1 deletion examples/dqn/config_cartpole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ buffer:

# logger
logger:
backend: null
backend: wandb
project_name: torchrl_example_dqn
group_name: null
exp_name: DQN
test_interval: 50_000
num_test_episodes: 5
Expand Down
9 changes: 8 additions & 1 deletion examples/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,14 @@ def main(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend:
exp_name = generate_exp_name("DQN", f"Atari_mnih15_{cfg.env.env_name}")
logger = get_logger(
cfg.logger.backend, logger_name="dqn", experiment_name=exp_name
cfg.logger.backend,
logger_name="dqn",
experiment_name=exp_name,
wandb_kwargs={
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

# Create the test environment
Expand Down
Loading

0 comments on commit 52a12a8

Please sign in to comment.