diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ffdc112..74e813e 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -5,9 +5,9 @@ name: CI
 
 on:
   push:
-    branches: [ main ]
+    branches: [main]
   pull_request:
-    branches: [ main ]
+    branches: [main]
 
 jobs:
   build:
@@ -20,29 +20,33 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
 
     steps:
-    - uses: actions/checkout@v3
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        # cpu version of pytorch
-        pip3 install torch --index-url https://download.pytorch.org/whl/cpu
-        pip install .[tests]
-    - name: Lint with ruff
-      run: |
-        make lint
-    - name: Check codestyle
-      run: |
-        make check-codestyle
-    - name: Type check
-      run: |
-        make type
-    - name: Test with pytest
-      run: |
-        make pytest
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # Use uv for faster downloads
+          pip install uv
+          # cpu version of pytorch
+          # See https://github.com/astral-sh/uv/issues/1497
+          uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
+
+          uv pip install --system ".[tests]"
+      - name: Lint with ruff
+        run: |
+          make lint
+      - name: Check codestyle
+        run: |
+          make check-codestyle
+      - name: Type check
+        run: |
+          make type
+      - name: Test with pytest
+        run: |
+          make pytest
diff --git a/Makefile b/Makefile
index 7ff0de0..ba91655 100644
--- a/Makefile
+++ b/Makefile
@@ -17,19 +17,19 @@ type:
 	mypy
 
 lint:
 	# stop the build if there are Python syntax errors or undefined names
 	# see https://www.flake8rules.com/
-	ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
+	ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
 	# exit-zero treats all errors as warnings.
-	ruff ${LINT_PATHS} --exit-zero
+	ruff check ${LINT_PATHS} --exit-zero --output-format=concise
 
 format:
 	# Sort imports
-	isort ${LINT_PATHS}
+	ruff check --select I ${LINT_PATHS} --fix
 	# Reformat using black
 	black ${LINT_PATHS}
 
 check-codestyle:
 	# Sort imports
-	isort --check ${LINT_PATHS}
+	ruff check --select I ${LINT_PATHS}
 	# Reformat using black
 	black --check ${LINT_PATHS}
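Note on the lint gate above: the first `ruff check` invocation fails the build only on the most serious problems (E9: syntax and indentation errors, F63: invalid comparisons, F7: misplaced syntax statements such as a `return` outside a function, F82: undefined names), while the second `--exit-zero` pass reports the full rule set as warnings. A minimal sketch of code that the hard gate would reject; `broken.py` is a hypothetical file, not part of this repo:

```python
# broken.py -- hypothetical example of what --select=E9,F63,F7,F82 catches


def total_reward(rewards):
    # F821 (undefined name, selected via the F82 prefix): `rwards` is a typo
    return sum(rwards)


def is_terminal(state):
    # F632 (selected via the F63 prefix): identity comparison with a literal
    return state is "terminal"
```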
diff --git a/dqn_tutorial/dqn/dqn.py b/dqn_tutorial/dqn/dqn.py
index fed52a7..b959dbc 100644
--- a/dqn_tutorial/dqn/dqn.py
+++ b/dqn_tutorial/dqn/dqn.py
@@ -303,7 +303,7 @@ def run_dqn(
     #     eval_render_mode="human",
     # )
 
-    # import flappy_bird_gymnasium  # noqa: F401
+    # import flappy_bird_gymnasium
     #
     # run_dqn(
     #     env_id="FlappyBird-v0",
diff --git a/dqn_tutorial/dqn/evaluation.py b/dqn_tutorial/dqn/evaluation.py
index d9af859..233154a 100644
--- a/dqn_tutorial/dqn/evaluation.py
+++ b/dqn_tutorial/dqn/evaluation.py
@@ -1,10 +1,19 @@
+import warnings
 from pathlib import Path
 from typing import Optional
 
 import gymnasium as gym
 import numpy as np
 from gymnasium import spaces
-from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder
+
+try:
+    from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder
+
+    gym_v1 = False
+except ImportError:
+    from gymnasium.wrappers import RecordVideo
+
+    gym_v1 = True
 
 from dqn_tutorial.dqn.collect_data import epsilon_greedy_action_selection
 from dqn_tutorial.dqn.q_network import QNetwork
@@ -25,19 +34,28 @@ def evaluate_policy(
     :param q_net: The Q-network to evaluate
     :param n_eval_episodes: The number of episodes to evaluate the policy on
     :param eval_exploration_rate: The exploration rate to use during evaluation
+    :param video_name: When set, the filename of the video to record.
     """
-    assert isinstance(eval_env.action_space, spaces.Discrete)
-
     # Setup video recorder
     video_recorder = None
     if video_name is not None and eval_env.render_mode == "rgb_array":
         video_path = Path(__file__).parent.parent.parent / "logs" / "videos" / video_name
         video_path.parent.mkdir(parents=True, exist_ok=True)
-        video_recorder = VideoRecorder(
-            env=eval_env,
-            base_path=str(video_path),
-        )
+        if gym_v1:
+            # New gym recorder always wants to cut video into episodes,
+            # set video length big enough but not to inf (will cut into episodes)
+            # Silence warnings when the folder already exists
+            warnings.filterwarnings("ignore", category=UserWarning, module="gymnasium.wrappers.rendering")
+            eval_env = RecordVideo(eval_env, str(video_path.parent), step_trigger=lambda _: False, video_length=100_000)
+            eval_env.start_recording(video_name)
+        else:
+            video_recorder = VideoRecorder(
+                env=eval_env,
+                base_path=str(video_path),
+            )
+
+    assert isinstance(eval_env.action_space, spaces.Discrete)
 
     episode_returns = []
     for _ in range(n_eval_episodes):
@@ -70,6 +88,9 @@
     if video_recorder is not None:
         print(f"Saving video to {video_recorder.path}")
         video_recorder.close()
+    elif gym_v1 and isinstance(eval_env, RecordVideo):
+        print(f"Saving video to {video_path}.mp4")
+        eval_env.close()
 
     # Print mean and std of the episode rewards
     print(f"Mean episode reward: {np.mean(episode_returns):.2f} +/- {np.std(episode_returns):.2f}")
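For readers unfamiliar with the gymnasium 1.0 change that drives the `evaluation.py` edits: the old `VideoRecorder` class (explicit `capture_frame()` calls) was removed, and the `RecordVideo` wrapper records inside `step()`/`reset()` instead. A minimal standalone sketch of the recording path used above, assuming gymnasium >= 1.0 with moviepy installed; `cartpole_demo` is an arbitrary example name:

```python
import gymnasium as gym
from gymnasium.wrappers import RecordVideo

env = gym.make("CartPole-v1", render_mode="rgb_array")
# A step_trigger that never fires disables automatic recording, and a large
# video_length keeps the wrapper from cutting the video at episode boundaries
env = RecordVideo(env, "logs/videos", step_trigger=lambda _: False, video_length=100_000)
# Start recording manually under a chosen file name instead
env.start_recording("cartpole_demo")

obs, _ = env.reset(seed=0)
done = False
while not done:
    # Frames are captured inside step(), no per-frame call needed
    obs, reward, terminated, truncated, _ = env.step(env.action_space.sample())
    done = terminated or truncated

env.close()  # stops recording and writes logs/videos/cartpole_demo.mp4
```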
""" + from functools import partial from pathlib import Path from typing import Optional diff --git a/notebooks/1_fitted_q_iteration_fqi.ipynb b/notebooks/1_fitted_q_iteration_fqi.ipynb index 3de1c88..5905d79 100644 --- a/notebooks/1_fitted_q_iteration_fqi.ipynb +++ b/notebooks/1_fitted_q_iteration_fqi.ipynb @@ -622,7 +622,7 @@ "source": [ "import os\n", "\n", - "from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder\n", + "from gymnasium.wrappers import RecordVideo\n", "\n", "\n", "def evaluate(\n", @@ -636,14 +636,13 @@ " done = False\n", "\n", " # Setup video recorder\n", - " video_recorder = None\n", " if video_name is not None and env.render_mode == \"rgb_array\":\n", " os.makedirs(\"../logs/videos/\", exist_ok=True)\n", "\n", - " video_recorder = VideoRecorder(\n", - " env=env,\n", - " base_path=f\"../logs/videos/{video_name}\",\n", - " )\n", + " # New gym recorder always wants to cut video into episodes,\n", + " # set video length big enough but not to inf (will cut into episodes)\n", + " eval_env = RecordVideo(eval_env, \"../logs/videos\", step_trigger=lambda _: False, video_length=100_000)\n", + " eval_env.start_recording(video_name)\n", "\n", " current_obs, _ = env.reset()\n", " # Number of discrete actions\n", @@ -651,10 +650,6 @@ " assert isinstance(env.action_space, spaces.Discrete), \"FQI only support discrete actions\"\n", "\n", " while total_episodes < n_eval_episodes:\n", - " # Record video\n", - " if video_recorder is not None:\n", - " video_recorder.capture_frame()\n", - "\n", " ### YOUR CODE HERE\n", "\n", " # Retrieve the q-values for the current observation\n", @@ -678,9 +673,9 @@ " total_episodes += 1\n", " current_obs, _ = env.reset()\n", "\n", - " if video_recorder is not None:\n", - " print(f\"Saving video to {video_recorder.path}\")\n", - " video_recorder.close()\n", + " if isinstance(eval_env, RecordVideo):\n", + " print(f\"Saving video to ../logs/videos/{video_name}\")\n", + " eval_env.close()\n", "\n", " print(f\"Total reward = {np.mean(episode_returns):.2f} +/- {np.std(episode_returns):.2f}\")" ] diff --git a/notebooks/solutions/1_fitted_q_iteration_fqi.ipynb b/notebooks/solutions/1_fitted_q_iteration_fqi.ipynb index c67f20a..1c5add1 100644 --- a/notebooks/solutions/1_fitted_q_iteration_fqi.ipynb +++ b/notebooks/solutions/1_fitted_q_iteration_fqi.ipynb @@ -599,7 +599,7 @@ "source": [ "import os\n", "\n", - "from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder\n", + "from gymnasium.wrappers import RecordVideo\n", "\n", "\n", "def evaluate(\n", @@ -617,19 +617,16 @@ " if video_name is not None and env.render_mode == \"rgb_array\":\n", " os.makedirs(\"../logs/videos/\", exist_ok=True)\n", "\n", - " video_recorder = VideoRecorder(\n", - " env=env,\n", - " base_path=f\"../logs/videos/{video_name}\",\n", - " )\n", + " # New gym recorder always wants to cut video into episodes,\n", + " # set video length big enough but not to inf (will cut into episodes)\n", + " eval_env = RecordVideo(eval_env, \"../logs/videos\", step_trigger=lambda _: False, video_length=100_000)\n", + " eval_env.start_recording(video_name)\n", "\n", " obs, _ = env.reset()\n", " n_actions = int(env.action_space.n)\n", " assert isinstance(env.action_space, spaces.Discrete), \"FQI only support discrete actions\"\n", "\n", " while total_episodes < n_eval_episodes:\n", - " # Record video\n", - " if video_recorder is not None:\n", - " video_recorder.capture_frame()\n", "\n", " ### YOUR CODE HERE\n", "\n", @@ -660,9 +657,9 @@ " total_episodes += 1\n", " obs, _ = 
diff --git a/pyproject.toml b/pyproject.toml
index a05f18e..1be4fd3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,9 +9,7 @@ packages = ["dqn_tutorial"]
 name = "dqn_tutorial"
 # version = "0.0.1", version is determined by setuptools_scm
 dynamic = ["version"]
-authors = [
-  { name="Antonin Raffin", email="antonin.raffin@dlr.de" },
-]
+authors = [{ name = "Antonin Raffin", email = "antonin.raffin@dlr.de" }]
 description = "A small example package"
 readme = "README.md"
 requires-python = ">=3.8"
@@ -21,14 +19,15 @@ classifiers = [
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
     "License :: OSI Approved :: MIT License",
     "Operating System :: OS Independent",
 ]
 dependencies = [
     "numpy",
-    "gymnasium[classic-control]>=0.28.1,<1.0",
+    "gymnasium[classic-control,other]>=0.29.1,<1.1.0",
     "scikit-learn",
-    "torch>=1.7.0",
+    "torch>=2.4.0",
 ]
 
 [tool.setuptools_scm]
@@ -41,10 +40,8 @@ tests = [
     "pytest-cov",
     # Type check
     "mypy",
-    # Lint code (flake8 replacement)
+    # Lint code and format
     "ruff",
-    # Sort imports
-    "isort>=5.0",
     # Reformat
     "black",
 ]
@@ -58,25 +55,23 @@
 line-length = 127
 # Assume Python 3.8
 target-version = "py38"
+
+[tool.ruff.lint]
 # See https://beta.ruff.rs/docs/rules/
 select = ["E", "F", "B", "UP", "C90", "RUF"]
 # Ignore explicit stacklevel`
 ignore = ["B028"]
 
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 
-[tool.ruff.mccabe]
+[tool.ruff.lint.mccabe]
 # Unlike Flake8, default to a complexity level of 10.
 max-complexity = 15
 
 [tool.black]
 line-length = 127
 
-[tool.isort]
-profile = "black"
-line_length = 127
-src_paths = ["dqn_tutorial"]
 
 [tool.mypy]
 ignore_missing_imports = true
@@ -90,23 +85,21 @@ show_error_codes = true
 
 [tool.pytest.ini_options]
 # Deterministic ordering for tests; useful for pytest-xdist.
-env = [
-    "PYTHONHASHSEED=0"
-]
+# env = ["PYTHONHASHSEED=0"]
 
-filterwarnings = [
-]
+filterwarnings = []
 markers = [
-    "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
+    "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
 ]
 
 [tool.coverage.run]
 disable_warnings = ["couldnt-parse"]
 branch = false
-omit = [
-    "tests/*",
-    "dqn_tutorial/_version.py",
-]
+omit = ["tests/*", "dqn_tutorial/_version.py"]
 
 [tool.coverage.report]
-exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
+exclude_lines = [
+    "pragma: no cover",
+    "raise NotImplementedError()",
+    "if typing.TYPE_CHECKING:",
+]
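Two notes on the `pyproject.toml` hunks. First, the moved ruff tables must live under `[tool.ruff.lint.*]`, not `[tool.lint.ruff.*]`; the reconstruction above uses the correct names, since ruff silently ignores unknown tool tables. Second, the widened pin `gymnasium[classic-control,other]>=0.29.1,<1.1.0` deliberately spans the 1.0 API break (the `other` extra pulls in moviepy, which `RecordVideo` needs to encode mp4 files), and that is what forces the `gym_v1` shim in `evaluation.py`. An alternative to the try/except import probe would be to branch on the installed version explicitly, a sketch rather than what this patch does:

```python
# Detect gymnasium 1.x by installed version instead of by import failure
from importlib.metadata import version

gym_v1 = int(version("gymnasium").split(".")[0]) >= 1
```

The try/except form has the advantage of keying on the symbol actually being present rather than on version numbering.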
diff --git a/tests/test_dqn.py b/tests/test_dqn.py
index f11e822..c18baf4 100644
--- a/tests/test_dqn.py
+++ b/tests/test_dqn.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import gymnasium as gym
 import numpy as np
 import torch as th
@@ -5,6 +7,7 @@
 from dqn_tutorial.dqn import QNetwork, ReplayBuffer, collect_one_step, linear_schedule
 from dqn_tutorial.dqn.dqn import run_dqn
 from dqn_tutorial.dqn.dqn_no_target import run_dqn_no_target
+from dqn_tutorial.dqn.evaluation import evaluate_policy
 
 
 def test_q_net():
@@ -67,3 +70,17 @@ def test_dqn_run():
 
 def test_dqn_notarget_run():
     run_dqn_no_target(n_timesteps=1000, evaluation_interval=500)
+
+
+def test_record_video():
+    env = gym.make("CartPole-v1", render_mode="rgb_array")
+    q_net = QNetwork(env.observation_space, env.action_space)
+
+    video_path = Path(__file__).parent.parent / "logs" / "videos" / "test_video_record.mp4"
+    if video_path.is_file():
+        video_path.unlink()
+
+    evaluate_policy(env, q_net, n_eval_episodes=5, video_name="test_video_record")
+
+    assert video_path.is_file()
+    video_path.unlink()
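One caveat on the new test: `test_record_video` exercises the recording path end to end, so it needs moviepy (installed through the `other` extra above) and writes into the repository's real `logs/videos/` folder. If the suite ever needs to pass in a minimal environment, a guarded variant could skip instead of fail, a sketch rather than part of this patch:

```python
import pytest


def test_record_video():
    # Skip gracefully when the optional video encoder is missing
    pytest.importorskip("moviepy")
    # ... body of the test unchanged ...
```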