Skip to content

Commit

Permalink
Merge pull request #1 from araffin/feat/gym-v1
Browse files Browse the repository at this point in the history
Gymnasium v1 support
  • Loading branch information
araffin authored Nov 5, 2024
2 parents 3070ad0 + b195c6f commit d2d7f39
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 87 deletions.
56 changes: 30 additions & 26 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ name: CI

on:
push:
branches: [ main ]
branches: [main]
pull_request:
branches: [ main ]
branches: [main]

jobs:
build:
Expand All @@ -20,29 +20,33 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
pip install .[tests]
- name: Lint with ruff
run: |
make lint
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
uv pip install --system ".[tests]"
- name: Lint with ruff
run: |
make lint
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero --output-format=concise

format:
# Sort imports
isort ${LINT_PATHS}
ruff check --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}

check-codestyle:
# Check that imports are sorted (no files are modified)
isort --check ${LINT_PATHS}
ruff check --select I ${LINT_PATHS}
# Check formatting using black (no files are modified)
black --check ${LINT_PATHS}

Expand Down
2 changes: 1 addition & 1 deletion dqn_tutorial/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def run_dqn(
# eval_render_mode="human",
# )

# import flappy_bird_gymnasium # noqa: F401
# import flappy_bird_gymnasium
#
# run_dqn(
# env_id="FlappyBird-v0",
Expand Down
35 changes: 28 additions & 7 deletions dqn_tutorial/dqn/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import warnings
from pathlib import Path
from typing import Optional

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder

try:
from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder

gym_v1 = False
except ImportError:
from gymnasium.wrappers import RecordVideo

gym_v1 = True

from dqn_tutorial.dqn.collect_data import epsilon_greedy_action_selection
from dqn_tutorial.dqn.q_network import QNetwork
Expand All @@ -25,19 +34,28 @@ def evaluate_policy(
:param q_net: The Q-network to evaluate
:param n_eval_episodes: The number of episodes to evaluate the policy on
:param eval_exploration_rate: The exploration rate to use during evaluation
:param video_name: When set, the filename of the video to record.
"""
assert isinstance(eval_env.action_space, spaces.Discrete)

# Setup video recorder
video_recorder = None
if video_name is not None and eval_env.render_mode == "rgb_array":
video_path = Path(__file__).parent.parent.parent / "logs" / "videos" / video_name
video_path.parent.mkdir(parents=True, exist_ok=True)

video_recorder = VideoRecorder(
env=eval_env,
base_path=str(video_path),
)
if gym_v1:
# The Gymnasium v1 RecordVideo wrapper cuts the recording at every episode end
# when video_length is infinite, so use a large but finite video_length instead
# Silence warnings when the folder already exists
warnings.filterwarnings("ignore", category=UserWarning, module="gymnasium.wrappers.rendering")
eval_env = RecordVideo(eval_env, str(video_path.parent), step_trigger=lambda _: False, video_length=100_000)
eval_env.start_recording(video_name)
else:
video_recorder = VideoRecorder(
env=eval_env,
base_path=str(video_path),
)

assert isinstance(eval_env.action_space, spaces.Discrete)

episode_returns = []
for _ in range(n_eval_episodes):
Expand Down Expand Up @@ -70,6 +88,9 @@ def evaluate_policy(
if video_recorder is not None:
print(f"Saving video to {video_recorder.path}")
video_recorder.close()
elif isinstance(eval_env, RecordVideo):
print(f"Saving video to {video_path}.mp4")
eval_env.close()

# Print mean and std of the episode rewards
print(f"Mean episode reward: {np.mean(episode_returns):.2f} +/- {np.std(episode_returns):.2f}")
1 change: 1 addition & 0 deletions dqn_tutorial/fqi/fqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
by Ernst et al. and
"Neural fitted Q iteration" by Martin Riedmiller.
"""

from functools import partial
from pathlib import Path
from typing import Optional
Expand Down
21 changes: 8 additions & 13 deletions notebooks/1_fitted_q_iteration_fqi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@
"source": [
"import os\n",
"\n",
"from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder\n",
"from gymnasium.wrappers import RecordVideo\n",
"\n",
"\n",
"def evaluate(\n",
Expand All @@ -636,25 +636,20 @@
" done = False\n",
"\n",
" # Setup video recorder\n",
" video_recorder = None\n",
" if video_name is not None and env.render_mode == \"rgb_array\":\n",
" os.makedirs(\"../logs/videos/\", exist_ok=True)\n",
"\n",
" video_recorder = VideoRecorder(\n",
" env=env,\n",
" base_path=f\"../logs/videos/{video_name}\",\n",
" )\n",
" # New gym recorder always wants to cut video into episodes,\n",
" # set video length big enough but not to inf (will cut into episodes)\n",
" eval_env = RecordVideo(eval_env, \"../logs/videos\", step_trigger=lambda _: False, video_length=100_000)\n",
" eval_env.start_recording(video_name)\n",
"\n",
" current_obs, _ = env.reset()\n",
" # Number of discrete actions\n",
" n_actions = int(env.action_space.n)\n",
" assert isinstance(env.action_space, spaces.Discrete), \"FQI only support discrete actions\"\n",
"\n",
" while total_episodes < n_eval_episodes:\n",
" # Record video\n",
" if video_recorder is not None:\n",
" video_recorder.capture_frame()\n",
"\n",
" ### YOUR CODE HERE\n",
"\n",
" # Retrieve the q-values for the current observation\n",
Expand All @@ -678,9 +673,9 @@
" total_episodes += 1\n",
" current_obs, _ = env.reset()\n",
"\n",
" if video_recorder is not None:\n",
" print(f\"Saving video to {video_recorder.path}\")\n",
" video_recorder.close()\n",
" if isinstance(eval_env, RecordVideo):\n",
" print(f\"Saving video to ../logs/videos/{video_name}\")\n",
" eval_env.close()\n",
"\n",
" print(f\"Total reward = {np.mean(episode_returns):.2f} +/- {np.std(episode_returns):.2f}\")"
]
Expand Down
19 changes: 8 additions & 11 deletions notebooks/solutions/1_fitted_q_iteration_fqi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@
"source": [
"import os\n",
"\n",
"from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder\n",
"from gymnasium.wrappers import RecordVideo\n",
"\n",
"\n",
"def evaluate(\n",
Expand All @@ -617,19 +617,16 @@
" if video_name is not None and env.render_mode == \"rgb_array\":\n",
" os.makedirs(\"../logs/videos/\", exist_ok=True)\n",
"\n",
" video_recorder = VideoRecorder(\n",
" env=env,\n",
" base_path=f\"../logs/videos/{video_name}\",\n",
" )\n",
" # New gym recorder always wants to cut video into episodes,\n",
" # set video length big enough but not to inf (will cut into episodes)\n",
" eval_env = RecordVideo(eval_env, \"../logs/videos\", step_trigger=lambda _: False, video_length=100_000)\n",
" eval_env.start_recording(video_name)\n",
"\n",
" obs, _ = env.reset()\n",
" n_actions = int(env.action_space.n)\n",
" assert isinstance(env.action_space, spaces.Discrete), \"FQI only support discrete actions\"\n",
"\n",
" while total_episodes < n_eval_episodes:\n",
" # Record video\n",
" if video_recorder is not None:\n",
" video_recorder.capture_frame()\n",
"\n",
" ### YOUR CODE HERE\n",
"\n",
Expand Down Expand Up @@ -660,9 +657,9 @@
" total_episodes += 1\n",
" obs, _ = env.reset()\n",
"\n",
" if video_recorder is not None:\n",
" print(f\"Saving video to {video_recorder.path}\")\n",
" video_recorder.close()\n",
" if isinstance(eval_env, RecordVideo):\n",
" print(f\"Saving video to ../logs/videos/{video_name}\")\n",
" eval_env.close()\n",
"\n",
" print(f\"Total reward = {np.mean(episode_returns):.2f} +/- {np.std(episode_returns):.2f}\")"
]
Expand Down
43 changes: 18 additions & 25 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ packages = ["dqn_tutorial"]
name = "dqn_tutorial"
# version = "0.0.1", version is determined by setuptools_scm
dynamic = ["version"]
authors = [
{ name="Antonin Raffin", email="[email protected]" },
]
authors = [{ name = "Antonin Raffin", email = "[email protected]" }]
description = "A small example package"
readme = "README.md"
requires-python = ">=3.8"
Expand All @@ -21,14 +19,15 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"numpy",
"gymnasium[classic-control]>=0.28.1,<1.0",
"gymnasium[classic-control,other]>=0.29.1,<1.1.0",
"scikit-learn",
"torch>=1.7.0",
"torch>=2.4.0",
]

[tool.setuptools_scm]
Expand All @@ -41,10 +40,8 @@ tests = [
"pytest-cov",
# Type check
"mypy",
# Lint code (flake8 replacement)
# Lint code and format
"ruff",
# Sort imports
"isort>=5.0",
# Reformat
"black",
]
Expand All @@ -58,25 +55,23 @@ tests = [
line-length = 127
# Assume Python 3.8
target-version = "py38"

[tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/
select = ["E", "F", "B", "UP", "C90", "RUF"]
# Ignore B028 (`warnings.warn` called without an explicit `stacklevel`)
ignore = ["B028"]

[tool.ruff.per-file-ignores]
[tool.lint.ruff.per-file-ignores]


[tool.ruff.mccabe]
[tool.lint.ruff.mccabe]
# Ruff's default maximum McCabe complexity is 10 (Flake8 disables the check by default); raise it to 15.
max-complexity = 15

[tool.black]
line-length = 127

[tool.isort]
profile = "black"
line_length = 127
src_paths = ["dqn_tutorial"]

[tool.mypy]
ignore_missing_imports = true
Expand All @@ -90,23 +85,21 @@ show_error_codes = true

[tool.pytest.ini_options]
# Deterministic ordering for tests; useful for pytest-xdist.
env = [
"PYTHONHASHSEED=0"
]
# env = ["PYTHONHASHSEED=0"]

filterwarnings = [
]
filterwarnings = []
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
]

[tool.coverage.run]
disable_warnings = ["couldnt-parse"]
branch = false
omit = [
"tests/*",
"dqn_tutorial/_version.py",
]
omit = ["tests/*", "dqn_tutorial/_version.py"]

[tool.coverage.report]
exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
exclude_lines = [
"pragma: no cover",
"raise NotImplementedError()",
"if typing.TYPE_CHECKING:",
]
Loading

0 comments on commit d2d7f39

Please sign in to comment.