From 48a06a7d197048866115fc0b5442e5f362de0e7e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 21 Sep 2023 17:21:34 +0100 Subject: [PATCH] add test Signed-off-by: Matteo Bettini --- test/conftest.py | 13 +++++++------ test/test_pettingzoo.py | 17 ++++++++++++----- test/test_vmas.py | 12 +++++++++--- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 26670db1..8205bb0a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,5 +1,6 @@ import os import shutil +import uuid from pathlib import Path import pytest @@ -15,19 +16,19 @@ def pytest_sessionstart(session): Called after the Session object has been created and before performing collection and entering the run test loop. """ - folder_name = Path(os.getcwd()) - folder_name = folder_name / "tmp" - folder_name.mkdir(parents=False, exist_ok=True) + folder_name = Path(os.path.dirname(os.path.realpath(__file__))) + folder_name = folder_name / f"tmp_{str(uuid.uuid4())[:8]}" + folder_name.mkdir(parents=False, exist_ok=False) os.chdir(folder_name) + session._tmp_folder = folder_name def pytest_sessionfinish(session, exitstatus): """ Called after whole test run finished, right before returning the exit status to the system. - """ - folder_name = Path(os.getcwd()) / "tmp" - shutil.rmtree(folder_name) + #""" + shutil.rmtree(session._tmp_folder) @pytest.fixture diff --git a/test/test_pettingzoo.py b/test/test_pettingzoo.py index 56b6cf0f..a4a60df3 100644 --- a/test/test_pettingzoo.py +++ b/test/test_pettingzoo.py @@ -21,27 +21,34 @@ @pytest.mark.skipif(not _has_pettingzoo, reason="PettingZoo not found") class TestPettingzoo: @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) - @pytest.mark.parametrize("continuous", [True, False]) + @pytest.mark.parametrize("prefer_continuous", [True, False]) @pytest.mark.parametrize("task", list(PettingZooTask)) def test_all_algos_all_tasks( self, algo_config: AlgorithmConfig, task: Task, - continuous, + prefer_continuous, experiment_config, mlp_sequence_config, ): + # To not run the same test twice + if (prefer_continuous and not algo_config.supports_continuous_actions()) or ( + not prefer_continuous and not algo_config.supports_discrete_actions() + ): + return + + # To not run unsupported algo-task pairs if ( not task.supports_continuous_actions() - and (continuous or not algo_config.supports_discrete_actions()) + and not algo_config.supports_discrete_actions() ) or ( not task.supports_discrete_actions() - and (not continuous or not algo_config.supports_continuous_actions()) + and not algo_config.supports_continuous_actions() ): return task = task.get_from_yaml() - experiment_config.prefer_continuous_actions = continuous + experiment_config.prefer_continuous_actions = prefer_continuous experiment = Experiment( algorithm_config=algo_config.get_from_yaml(), model_config=mlp_sequence_config, diff --git a/test/test_vmas.py b/test/test_vmas.py index 07141a4c..f4a88f0e 100644 --- a/test/test_vmas.py +++ b/test/test_vmas.py @@ -21,18 +21,24 @@ @pytest.mark.skipif(not _has_vmas, reason="VMAS not found") class TestVmas: @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) - @pytest.mark.parametrize("continuous", [True, False]) + @pytest.mark.parametrize("prefer_continuous", [True, False]) @pytest.mark.parametrize("task", list(VmasTask)) def test_all_algos_all_tasks( self, algo_config: AlgorithmConfig, task: Task, - continuous, + prefer_continuous, experiment_config, mlp_sequence_config, ): + # To not run the same test twice + if (prefer_continuous and not algo_config.supports_continuous_actions()) or ( + not prefer_continuous and not algo_config.supports_discrete_actions() + ): + return + task = task.get_from_yaml() - experiment_config.prefer_continuous_actions = continuous + experiment_config.prefer_continuous_actions = prefer_continuous experiment = Experiment( algorithm_config=algo_config.get_from_yaml(), model_config=mlp_sequence_config,