Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Sep 21, 2023
1 parent 1678fdc commit 48a06a7
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 14 deletions.
13 changes: 7 additions & 6 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import shutil
import uuid
from pathlib import Path

import pytest
Expand All @@ -15,19 +16,19 @@ def pytest_sessionstart(session):
Called after the Session object has been created and
before performing collection and entering the run test loop.
"""
folder_name = Path(os.getcwd())
folder_name = folder_name / "tmp"
folder_name.mkdir(parents=False, exist_ok=True)
folder_name = Path(os.path.dirname(os.path.realpath(__file__)))
folder_name = folder_name / f"tmp_{str(uuid.uuid4())[:8]}"
folder_name.mkdir(parents=False, exist_ok=False)
os.chdir(folder_name)
session._tmp_folder = folder_name


def pytest_sessionfinish(session, exitstatus):
"""
Called after whole test run finished, right before
returning the exit status to the system.
"""
folder_name = Path(os.getcwd()) / "tmp"
shutil.rmtree(folder_name)
#"""
shutil.rmtree(session._tmp_folder)


@pytest.fixture
Expand Down
17 changes: 12 additions & 5 deletions test/test_pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,34 @@
@pytest.mark.skipif(not _has_pettingzoo, reason="PettingZoo not found")
class TestPettingzoo:
@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("continuous", [True, False])
@pytest.mark.parametrize("prefer_continuous", [True, False])
@pytest.mark.parametrize("task", list(PettingZooTask))
def test_all_algos_all_tasks(
self,
algo_config: AlgorithmConfig,
task: Task,
continuous,
prefer_continuous,
experiment_config,
mlp_sequence_config,
):
# To not run the same test twice
if (prefer_continuous and not algo_config.supports_continuous_actions()) or (
not prefer_continuous and not algo_config.supports_discrete_actions()
):
return

# To not run unsupported algo-task pairs
if (
not task.supports_continuous_actions()
and (continuous or not algo_config.supports_discrete_actions())
and not algo_config.supports_discrete_actions()
) or (
not task.supports_discrete_actions()
and (not continuous or not algo_config.supports_continuous_actions())
and not algo_config.supports_continuous_actions()
):
return

task = task.get_from_yaml()
experiment_config.prefer_continuous_actions = continuous
experiment_config.prefer_continuous_actions = prefer_continuous
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=mlp_sequence_config,
Expand Down
12 changes: 9 additions & 3 deletions test/test_vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,24 @@
@pytest.mark.skipif(not _has_vmas, reason="VMAS not found")
class TestVmas:
@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("continuous", [True, False])
@pytest.mark.parametrize("prefer_continuous", [True, False])
@pytest.mark.parametrize("task", list(VmasTask))
def test_all_algos_all_tasks(
self,
algo_config: AlgorithmConfig,
task: Task,
continuous,
prefer_continuous,
experiment_config,
mlp_sequence_config,
):
# To not run the same test twice
if (prefer_continuous and not algo_config.supports_continuous_actions()) or (
not prefer_continuous and not algo_config.supports_discrete_actions()
):
return

task = task.get_from_yaml()
experiment_config.prefer_continuous_actions = continuous
experiment_config.prefer_continuous_actions = prefer_continuous
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=mlp_sequence_config,
Expand Down

0 comments on commit 48a06a7

Please sign in to comment.