From 3a7cc996ebe8f9f40c30fb096ca8a38ac983b77f Mon Sep 17 00:00:00 2001 From: "Joel Z. Leibo" Date: Wed, 13 Nov 2024 11:47:55 -0800 Subject: [PATCH] Fix environment factory test (pytype issue) PiperOrigin-RevId: 696216987 Change-Id: Ic3e7c32f2bc7c326966f70f7674dd8cc2ad850a9 --- concordia/environment/scenes/runner.py | 4 +- .../factory/environment/basic_game_master.py | 9 ++- .../factory/environment/factories_test.py | 14 +++-- concordia/utils/helper_functions.py | 42 ------------- concordia/utils/json.py | 59 +++++++++++++++++++ 5 files changed, 74 insertions(+), 54 deletions(-) create mode 100644 concordia/utils/json.py diff --git a/concordia/environment/scenes/runner.py b/concordia/environment/scenes/runner.py index b3bdfed1..1ce44017 100644 --- a/concordia/environment/scenes/runner.py +++ b/concordia/environment/scenes/runner.py @@ -21,7 +21,7 @@ from concordia.typing import clock as game_clock from concordia.typing import logging as logging_lib from concordia.typing import scene as scene_lib -from concordia.utils import helper_functions +from concordia.utils import json as json_lib def _get_interscene_messages( @@ -145,7 +145,7 @@ def run_scenes( serialized_agents = {} for participant in participants: serialized_agents = {} - json_representation = helper_functions.save_to_json(participant) + json_representation = json_lib.save_to_json(participant) serialized_agents[participant.name] = json_representation if compute_metrics is not None: diff --git a/concordia/factory/environment/basic_game_master.py b/concordia/factory/environment/basic_game_master.py index 570934a9..44957149 100644 --- a/concordia/factory/environment/basic_game_master.py +++ b/concordia/factory/environment/basic_game_master.py @@ -18,8 +18,7 @@ import operator from concordia import components as generic_components -from concordia.agents import deprecated_agent -from concordia.agents import entity_agent +from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory from concordia.associative_memory import blank_memories from concordia.associative_memory import importance_function @@ -42,7 +41,7 @@ def build_game_master( embedder: Callable[[str], np.ndarray], importance_model: importance_function.ImportanceModel, clock: game_clock.MultiIntervalClock, - players: Sequence[deprecated_agent.BasicAgent], + players: Sequence[entity_agent_with_logging.EntityAgentWithLogging], shared_memories: Sequence[str], shared_context: str, blank_memory_factory: blank_memories.MemoryFactory, @@ -191,7 +190,7 @@ def build_decision_scene_game_master( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, clock: game_clock.MultiIntervalClock, - players: Sequence[deprecated_agent.BasicAgent], + players: Sequence[entity_agent_with_logging.EntityAgentWithLogging], decision_action_spec: agent_lib.ActionSpec, payoffs: gm_components.schelling_diagram_payoffs.SchellingPayoffs, verbose: bool = False, @@ -289,7 +288,7 @@ def create_html_log( def run_simulation( *, model: language_model.LanguageModel, - players: Sequence[deprecated_agent.BasicAgent | entity_agent.EntityAgent], + players: Sequence[entity_agent_with_logging.EntityAgentWithLogging], primary_environment: game_master.GameMaster, clock: game_clock.MultiIntervalClock, scenes: Sequence[scene_lib.SceneSpec], diff --git a/concordia/factory/environment/factories_test.py b/concordia/factory/environment/factories_test.py index 72700091..b015b4af 100644 --- a/concordia/factory/environment/factories_test.py +++ b/concordia/factory/environment/factories_test.py @@ -19,12 +19,13 @@ from absl.testing import absltest from absl.testing import parameterized -from concordia.agents import deprecated_agent +from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory from concordia.associative_memory import blank_memories from concordia.associative_memory import formative_memories from concordia.associative_memory import importance_function from concordia.clocks import game_clock +from concordia.components import agent as agent_components from concordia.environment import game_master from concordia.factory.environment import basic_game_master from concordia.language_model import no_language_model @@ -63,12 +64,15 @@ def test_give_me_a_name(self, environment_name: str): importance=importance_model_gm.importance, clock_now=clock.now, ) - player_a = deprecated_agent.BasicAgent( + act_component = agent_components.concat_act_component.ConcatActComponent( model=model, - agent_name='Rakshit', clock=clock, - components=[], - update_interval=datetime.timedelta(hours=1), + component_order=[], + ) + player_a = entity_agent_with_logging.EntityAgentWithLogging( + agent_name='Rakshit', + act_component=act_component, + context_components={}, ) players = [player_a] diff --git a/concordia/utils/helper_functions.py b/concordia/utils/helper_functions.py index a28e67ed..920f4d8c 100644 --- a/concordia/utils/helper_functions.py +++ b/concordia/utils/helper_functions.py @@ -17,13 +17,10 @@ from collections.abc import Iterable, Sequence import datetime import functools -import json -from concordia.agents import entity_agent_with_logging from concordia.document import interactive_document from concordia.language_model import language_model from concordia.typing import component -from concordia.typing import entity_component from concordia.utils import concurrency @@ -147,42 +144,3 @@ def apply_recursively( getattr(parent_component, function_name)() else: getattr(parent_component, function_name)(function_arg) - - -def save_to_json( - agent: entity_agent_with_logging.EntityAgentWithLogging, -) -> str: - """Saves an agent to JSON data. - - This function saves the agent's state to a JSON string, which can be loaded - afterwards with `rebuild_from_json`. The JSON data - includes the state of the agent's context components, act component, memory, - agent name and the initial config. The clock, model and embedder are not - saved and will have to be provided when the agent is rebuilt. The agent must - be in the `READY` phase to be saved. - - Args: - agent: The agent to save. - - Returns: - A JSON string representing the agent's state. - - Raises: - ValueError: If the agent is not in the READY phase. - """ - - if agent.get_phase() != entity_component.Phase.READY: - raise ValueError('The agent must be in the `READY` phase to be saved.') - - data = { - component_name: agent.get_component(component_name).get_state() - for component_name in agent.get_all_context_components() - } - - data['act_component'] = agent.get_act_component().get_state() - - config = agent.get_config() - if config is not None: - data['agent_config'] = config.to_dict() - - return json.dumps(data) diff --git a/concordia/utils/json.py b/concordia/utils/json.py new file mode 100644 index 00000000..1534a259 --- /dev/null +++ b/concordia/utils/json.py @@ -0,0 +1,59 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Json helper functions.""" + +import json + +from concordia.agents import entity_agent_with_logging +from concordia.typing import entity_component + + +def save_to_json( + agent: entity_agent_with_logging.EntityAgentWithLogging, +) -> str: + """Saves an agent to JSON data. + + This function saves the agent's state to a JSON string, which can be loaded + afterwards with `rebuild_from_json`. The JSON data + includes the state of the agent's context components, act component, memory, + agent name and the initial config. The clock, model and embedder are not + saved and will have to be provided when the agent is rebuilt. The agent must + be in the `READY` phase to be saved. + + Args: + agent: The agent to save. + + Returns: + A JSON string representing the agent's state. + + Raises: + ValueError: If the agent is not in the READY phase. + """ + + if agent.get_phase() != entity_component.Phase.READY: + raise ValueError('The agent must be in the `READY` phase to be saved.') + + data = { + component_name: agent.get_component(component_name).get_state() + for component_name in agent.get_all_context_components() + } + + data['act_component'] = agent.get_act_component().get_state() + + config = agent.get_config() + if config is not None: + data['agent_config'] = config.to_dict() + + return json.dumps(data)