Skip to content

Commit

Permalink
[BugFix] Fix broken gym tests (#1980)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 1, 2024
1 parent cb9c8c5 commit 146341b
Show file tree
Hide file tree
Showing 16 changed files with 711 additions and 553 deletions.
4 changes: 2 additions & 2 deletions .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_contro
if [ "${CU_VERSION:-}" != cpu ] ; then
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
--timeout=120
--timeout=120 --mp_fork_if_no_cuda
else
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py \
--timeout=120
--timeout=120 --mp_fork_if_no_cuda
fi

coverage combine
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux_distributed/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ export BATCHED_PIPE_TIMEOUT=60

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py --instafail -v --durations 200
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py --instafail -v --durations 200 --mp_fork_if_no_cuda
coverage combine
coverage xml -i
75 changes: 26 additions & 49 deletions .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ apt-get update && apt-get install -y git wget libglew-dev libx11-dev x11proto-de
# solves "'extras_require' must be a dictionary"
pip install setuptools==65.3.0

mkdir third_party
cd third_party
git clone https://github.com/vmoens/gym
cd ..
#mkdir -p third_party
#cd third_party
#git clone https://github.com/vmoens/gym
#cd ..

# This version is installed initially (see environment.yml)
for GYM_VERSION in '0.13'
Expand All @@ -38,7 +38,7 @@ do

# delete the conda copy
conda deactivate
conda env remove --prefix ./cloned_env
conda env remove --prefix ./cloned_env -y
done

# gym[atari]==0.19 is broken, so we install only gym without dependencies.
Expand All @@ -57,7 +57,7 @@ do

# delete the conda copy
conda deactivate
conda env remove --prefix ./cloned_env
conda env remove --prefix ./cloned_env -y
done

# gym[atari]==0.20 installs ale-py==0.8, but this version is not compatible with gym<0.26, so we downgrade it.
Expand All @@ -76,7 +76,7 @@ do

# delete the conda copy
conda deactivate
conda env remove --prefix ./cloned_env
conda env remove --prefix ./cloned_env -y
done

for GYM_VERSION in '0.25'
Expand All @@ -92,7 +92,7 @@ do

# delete the conda copy
conda deactivate
conda env remove --prefix ./cloned_env
conda env remove --prefix ./cloned_env -y
done

# For this version "gym[accept-rom-license]" is required.
Expand All @@ -104,65 +104,42 @@ do
conda activate ./cloned_env

echo "Testing gym version: ${GYM_VERSION}"
pip3 install 'gym[accept-rom-license]'==$GYM_VERSION
pip3 install 'gym[atari]'==$GYM_VERSION
pip3 install 'gym[atari,accept-rom-license]'==$GYM_VERSION
pip3 install gym-super-mario-bros
$DIR/run_test.sh

# delete the conda copy
conda deactivate
conda env remove --prefix ./cloned_env
conda env remove --prefix ./cloned_env -y
done

# For this version "gym[accept-rom-license]" is required.
for GYM_VERSION in '0.27'
for GYM_VERSION in '0.27' '0.28'
do
# Create a copy of the conda env and work with this
conda deactivate
conda create --prefix ./cloned_env --clone ./env -y
conda activate ./cloned_env

echo "Testing gym version: ${GYM_VERSION}"
pip3 install 'gymnasium[accept-rom-license]'==$GYM_VERSION


if [[ $OSTYPE != 'darwin'* ]]; then
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
# rename them
PY_VERSION=$(python --version)
if [[ $PY_VERSION == *"3.7"* ]]; then
wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.8"* ]]; then
wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.9"* ]]; then
wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.10"* ]]; then
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.11"* ]]; then
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
mv ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
fi
pip install gymnasium[atari]
else
pip install gymnasium[atari]
fi
pip install mo-gymnasium
pip install gymnasium-robotics
pip3 install 'gymnasium[atari,accept-rom-license,ale-py]'==$GYM_VERSION

$DIR/run_test.sh

# delete the conda copy
conda deactivate
conda env remove --prefix ./cloned_env
conda env remove --prefix ./cloned_env -y
done

# Latest gymnasium
conda deactivate
conda create --prefix ./cloned_env --clone ./env -y
conda activate ./cloned_env

pip3 install 'gymnasium[accept-rom-license,ale-py,atari]' mo-gymnasium gymnasium-robotics -U

$DIR/run_test.sh

# delete the conda copy
conda deactivate
conda env remove --prefix ./cloned_env -y
2 changes: 1 addition & 1 deletion .github/unittest/linux_libs/scripts_gym/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_te

export DISPLAY=':99.0'
Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 &
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym and not isaac" --error-for-skips
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym and not isaac" --error-for-skips --mp_fork
coverage combine
coverage xml -i
79 changes: 55 additions & 24 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torchrl.envs import MultiThreadedEnv, ObservationNorm
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.envpool import _has_envpool
from torchrl.envs.libs.gym import _has_gym, GymEnv
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv
from torchrl.envs.transforms import (
Compose,
RewardClipping,
Expand All @@ -35,41 +35,72 @@
# Specified for test_utils.py
__version__ = "0.3"

# Default versions of the environments.
CARTPOLE_VERSIONED = "CartPole-v1"
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
PENDULUM_VERSIONED = "Pendulum-v1"
PONG_VERSIONED = "ALE/Pong-v5"

def CARTPOLE_VERSIONED():
# load gym
if gym_backend() is not None:
_set_gym_environments()
return _CARTPOLE_VERSIONED


def HALFCHEETAH_VERSIONED():
# load gym
if gym_backend() is not None:
_set_gym_environments()
return _HALFCHEETAH_VERSIONED


def PONG_VERSIONED():
# load gym
if gym_backend() is not None:
_set_gym_environments()
return _PONG_VERSIONED


def PENDULUM_VERSIONED():
# load gym
if gym_backend() is not None:
_set_gym_environments()
return _PENDULUM_VERSIONED


def _set_gym_environments():
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED

_CARTPOLE_VERSIONED = None
_HALFCHEETAH_VERSIONED = None
_PENDULUM_VERSIONED = None
_PONG_VERSIONED = None


@implement_for("gym", None, "0.21.0")
def _set_gym_environments(): # noqa: F811
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED

CARTPOLE_VERSIONED = "CartPole-v0"
HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
PENDULUM_VERSIONED = "Pendulum-v0"
PONG_VERSIONED = "Pong-v4"
_CARTPOLE_VERSIONED = "CartPole-v0"
_HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
_PENDULUM_VERSIONED = "Pendulum-v0"
_PONG_VERSIONED = "Pong-v4"


@implement_for("gym", "0.21.0", None)
def _set_gym_environments(): # noqa: F811
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED

CARTPOLE_VERSIONED = "CartPole-v1"
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
PENDULUM_VERSIONED = "Pendulum-v1"
PONG_VERSIONED = "ALE/Pong-v5"
_CARTPOLE_VERSIONED = "CartPole-v1"
_HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
_PENDULUM_VERSIONED = "Pendulum-v1"
_PONG_VERSIONED = "ALE/Pong-v5"


@implement_for("gymnasium")
def _set_gym_environments(): # noqa: F811
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED

CARTPOLE_VERSIONED = "CartPole-v1"
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
PENDULUM_VERSIONED = "Pendulum-v1"
PONG_VERSIONED = "ALE/Pong-v5"
_CARTPOLE_VERSIONED = "CartPole-v1"
_HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
_PENDULUM_VERSIONED = "Pendulum-v1"
_PONG_VERSIONED = "ALE/Pong-v5"


if _has_gym:
Expand Down Expand Up @@ -171,7 +202,7 @@ def create_env_fn():
return GymEnv(env_name, frame_skip=frame_skip, device=device)

else:
if env_name == PONG_VERSIONED:
if env_name == PONG_VERSIONED():

def create_env_fn():
base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
Expand Down Expand Up @@ -250,7 +281,7 @@ def _make_multithreaded_env(

torch.manual_seed(0)
multithreaded_kwargs = (
{"frame_skip": frame_skip} if env_name == PONG_VERSIONED else {}
{"frame_skip": frame_skip} if env_name == PONG_VERSIONED() else {}
)
env_multithread = MultiThreadedEnv(
N,
Expand All @@ -274,7 +305,7 @@ def _make_multithreaded_env(

def get_transform_out(env_name, transformed_in, obs_key=None):

if env_name == PONG_VERSIONED:
if env_name == PONG_VERSIONED():
if obs_key is None:
obs_key = "pixels"

Expand Down
26 changes: 24 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import sys
import time
import warnings
from collections import defaultdict

import pytest

CALL_TIMES = defaultdict(lambda: 0.0)
IS_OSX = sys.platform == "darwin"


def pytest_sessionfinish(maxprint=50):
Expand Down Expand Up @@ -97,6 +97,20 @@ def pytest_addoption(parser):
"--runslow", action="store_true", default=False, help="run slow tests"
)

parser.addoption(
"--mp_fork",
action="store_true",
default=False,
help="Use 'fork' start method for mp dedicated tests.",
)

parser.addoption(
"--mp_fork_if_no_cuda",
action="store_true",
default=False,
help="Use 'fork' start method for mp dedicated tests only if there is no cuda device available.",
)


def pytest_configure(config):
config.addinivalue_line("markers", "slow: mark test as slow to run")
Expand All @@ -110,3 +124,11 @@ def pytest_collection_modifyitems(config, items):
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)


@pytest.fixture
def maybe_fork_ParallelEnv(request):
# Feature available from 0.4 only
from torchrl.envs import ParallelEnv

return ParallelEnv
2 changes: 1 addition & 1 deletion test/smoke_test_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_gym():
assert _has_gym
from _utils_internal import PONG_VERSIONED

env = GymEnv(PONG_VERSIONED)
env = GymEnv(PONG_VERSIONED())
env.reset()


Expand Down
10 changes: 6 additions & 4 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def make_env():
# This is currently necessary as the methods in GymWrapper may have mismatching backend
# versions.
with set_gym_backend(gym_backend()):
return TransformedEnv(GymEnv(PONG_VERSIONED, frame_skip=4), StepCounter())
return TransformedEnv(GymEnv(PONG_VERSIONED(), frame_skip=4), StepCounter())

if parallel:
env = ParallelEnv(2, make_env)
Expand Down Expand Up @@ -1076,7 +1076,9 @@ def test_collector_vecnorm_envcreator(static_seed):
from torchrl.envs.libs.gym import GymEnv

num_envs = 4
env_make = EnvCreator(lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED), VecNorm()))
env_make = EnvCreator(
lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm())
)
env_make = ParallelEnv(num_envs, env_make)

policy = RandomPolicy(env_make.action_spec)
Expand Down Expand Up @@ -1293,7 +1295,7 @@ def test_collector_output_keys(

policy = SafeModule(**policy_kwargs)

env_maker = lambda: GymEnv(PENDULUM_VERSIONED)
env_maker = lambda: GymEnv(PENDULUM_VERSIONED())

policy(env_maker().reset())

Expand Down Expand Up @@ -1432,7 +1434,7 @@ class TestAutoWrap:
def env_maker(self):
from torchrl.envs.libs.gym import GymEnv

return lambda: GymEnv(PENDULUM_VERSIONED)
return lambda: GymEnv(PENDULUM_VERSIONED())

def _create_collector_kwargs(self, env_maker, collector_class, policy):
collector_kwargs = {
Expand Down
Loading

0 comments on commit 146341b

Please sign in to comment.