diff --git a/hive/runners/utils.py b/hive/runners/utils.py
index 2eac3796..bd02ef32 100644
--- a/hive/runners/utils.py
+++ b/hive/runners/utils.py
@@ -227,9 +227,9 @@ def get_stacked_state(self, agent, observation):
         if self._stack_size == 1:
             return observation
+        while len(self._previous_observations[agent.id]) < self._stack_size - 1:
             self._previous_observations[agent.id].append(zeros_like(observation))
-
         stacked_observation = concatenate(
             list(self._previous_observations[agent.id]) + [observation]
         )
diff --git a/tests/hive/agents/test_dqn.py b/tests/hive/agents/test_dqn.py
index e91615cd..2f9b5f77 100644
--- a/tests/hive/agents/test_dqn.py
+++ b/tests/hive/agents/test_dqn.py
@@ -14,6 +14,7 @@
 from hive.replays import SimpleReplayBuffer
 from hive.utils import schedule
+import pytest_lazyfixture
 
 
 @pytest.fixture
 def env_spec():
@@ -25,11 +26,11 @@ def env_spec():
     """
 
 
 @pytest.fixture(
     params=[
-        pytest.lazy_fixture("xxxx_agent_with_mock_optimizer"),
-        pytest.lazy_fixture("dxxx_agent_with_mock_optimizer"),
-        pytest.lazy_fixture("xdxx_agent_with_mock_optimizer"),
+        pytest_lazyfixture.lazy_fixture("xxxx_agent_with_mock_optimizer"),
+        pytest_lazyfixture.lazy_fixture("dxxx_agent_with_mock_optimizer"),
+        pytest_lazyfixture.lazy_fixture("xdxx_agent_with_mock_optimizer"),
     ]
 )
 def agent_with_mock_optimizer(request):
diff --git a/tests/hive/utils/transition_info/test_transition_info.py b/tests/hive/utils/transition_info/test_transition_info.py
new file mode 100644
index 00000000..ae319943
--- /dev/null
+++ b/tests/hive/utils/transition_info/test_transition_info.py
@@ -0,0 +1,91 @@
+import os
+from argparse import Namespace
+
+import gym
+import pytest
+
+from hive.agents.dqn import DQNAgent
+from hive.agents.qnets.mlp import MLPNetwork
+from hive.runners.utils import TransitionInfo, load_config
+
+
+@pytest.fixture()
+def args():
+    return Namespace(
+        config="tests/hive/utils/transition_info/test_transition_info_config.yml",
+        agent_config=None,
+    )
+
+
+@pytest.fixture()
+def transition_info(args, tmpdir):
+    config = load_config(
+        args.config,
+        agent_config=args.agent_config,
+    )
+    config["save_dir"] = os.path.join(tmpdir, config["save_dir"])
+    env = gym.make("CartPole-v0")
+
+    # Two DQN agents sharing the same observation and action spaces.
+    agent0 = DQNAgent(
+        observation_space=env.observation_space,
+        action_space=env.action_space,
+        representation_net=MLPNetwork,
+        id=0,
+    )
+    agent1 = DQNAgent(
+        observation_space=env.observation_space,
+        action_space=env.action_space,
+        representation_net=MLPNetwork,
+        id=1,
+    )
+
+    agents = [agent0, agent1]
+    stack_size = 5
+    t_info = TransitionInfo(agents, stack_size)
+    return t_info, agents, config
+
+
+def test_start_agent(transition_info):
+    t_info, agents, config = transition_info
+    t_info.start_agent(agents[0])
+    assert t_info._started[agents[0].id]
+
+
+def test_is_started(transition_info):
+    t_info, agents, config = transition_info
+    t_info.start_agent(agents[0])
+    assert t_info.is_started(agents[0])
+    assert not t_info.is_started(agents[1])
+
+
+def test_update_reward(transition_info):
+    t_info, agents, config = transition_info
+    t_info.start_agent(agents[0])
+    t_info.update_reward(agents[0], 1.0)
+    assert t_info._transitions[t_info._agent_ids[0]]["reward"] == 1.0
+
+
+def test_update_all_rewards(transition_info):
+    t_info, agents, config = transition_info
+    rewards = [1.0, 2.0]
+    t_info.update_all_rewards(rewards)
+    assert t_info._transitions[t_info._agent_ids[0]]["reward"] == 1.0
+    assert t_info._transitions[t_info._agent_ids[1]]["reward"] == 2.0
+
+
+def test_get_info(transition_info):
+    t_info, agents, config = transition_info
+    info = t_info.get_info(agents[0], terminated=True, truncated=True)
+    assert info == {"reward": 0, "truncated": True, "terminated": True}
+    # The per-agent transition is reset once its info has been extracted.
+    assert t_info._transitions["0"] == {"reward": 0.0}
+
+
+def test_record_info(transition_info):
+    t_info, agents, config = transition_info
+    info = {"observation": 2, "reward": 1}
+    t_info.record_info(agents[0], info)
+    assert t_info._transitions["0"] == {"observation": 2, "reward": 1}
+    assert t_info._previous_observations["0"][-1] == 2
+
+
+def test_get_stacked_state(transition_info):
+    t_info, agents, config = transition_info
+    observation = 2
+    t_info._previous_observations[agents[0].id].append(3)
+    stacked_observation = t_info.get_stacked_state(agents[0], observation)
+    # With stack_size=5, the single stored observation is padded with three
+    # zeros before the current observation is appended.
+    assert list(stacked_observation) == [3, 0, 0, 0, 2]
diff --git a/tests/hive/utils/transition_info/test_transition_info_config.yml b/tests/hive/utils/transition_info/test_transition_info_config.yml
new file mode 100644
index 00000000..0705cc0b
--- /dev/null
+++ b/tests/hive/utils/transition_info/test_transition_info_config.yml
@@ -0,0 +1,54 @@
+# General training loop config
+run_name: &run_name 'dqn-metrics'
+train_steps: 450
+test_frequency: 10
+test_num_episodes: 1
+self_play: False
+num_agents: 2
+saving_schedule:
+  name: 'PeriodicSchedule'
+  kwargs:
+    off_value: False
+    on_value: True
+    period: 15000
+save_dir: 'experiment'
+
+environment:
+  name: 'GymEnv'
+  kwargs:
+    env_name: 'CartPole-v0'
+
+# List of agents for the experiment. In single-agent mode, only the first
+# agent in the list is used.
+agent:
+  name: 'DQNAgent'
+  kwargs:
+    representation_net:
+      name: 'MLPNetwork'
+      kwargs:
+        hidden_units: [256, 256]
+    optimizer_fn:
+      name: 'Adam'
+      kwargs: {}
+    id: 0
+    replay_buffer:
+      name: 'CircularReplayBuffer'
+      kwargs:
+        capacity: 10000
+        observation_dtype: 'np.float32'
+    discount_rate: .99
+    target_net_update_schedule:
+      name: 'PeriodicSchedule'
+      kwargs:
+        off_value: False
+        on_value: True
+        period: 100
+    epsilon_schedule:
+      name: 'ConstantSchedule'
+      kwargs:
+        value: .01
+    min_replay_history: 500
+    batch_size: 128
+    device: 'cpu'
+    log_frequency: 100