Skip to content

Commit

Permalink
Fix environment factory test (pytype issue)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696216987
Change-Id: Ic3e7c32f2bc7c326966f70f7674dd8cc2ad850a9
  • Loading branch information
jzleibo authored and copybara-github committed Nov 13, 2024
1 parent 0151181 commit 3a7cc99
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 54 deletions.
4 changes: 2 additions & 2 deletions concordia/environment/scenes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from concordia.typing import clock as game_clock
from concordia.typing import logging as logging_lib
from concordia.typing import scene as scene_lib
from concordia.utils import helper_functions
from concordia.utils import json as json_lib


def _get_interscene_messages(
Expand Down Expand Up @@ -145,7 +145,7 @@ def run_scenes(
serialized_agents = {}
for participant in participants:
serialized_agents = {}
json_representation = helper_functions.save_to_json(participant)
json_representation = json_lib.save_to_json(participant)
serialized_agents[participant.name] = json_representation

if compute_metrics is not None:
Expand Down
9 changes: 4 additions & 5 deletions concordia/factory/environment/basic_game_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import operator

from concordia import components as generic_components
from concordia.agents import deprecated_agent
from concordia.agents import entity_agent
from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import blank_memories
from concordia.associative_memory import importance_function
Expand All @@ -42,7 +41,7 @@ def build_game_master(
embedder: Callable[[str], np.ndarray],
importance_model: importance_function.ImportanceModel,
clock: game_clock.MultiIntervalClock,
players: Sequence[deprecated_agent.BasicAgent],
players: Sequence[entity_agent_with_logging.EntityAgentWithLogging],
shared_memories: Sequence[str],
shared_context: str,
blank_memory_factory: blank_memories.MemoryFactory,
Expand Down Expand Up @@ -191,7 +190,7 @@ def build_decision_scene_game_master(
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
clock: game_clock.MultiIntervalClock,
players: Sequence[deprecated_agent.BasicAgent],
players: Sequence[entity_agent_with_logging.EntityAgentWithLogging],
decision_action_spec: agent_lib.ActionSpec,
payoffs: gm_components.schelling_diagram_payoffs.SchellingPayoffs,
verbose: bool = False,
Expand Down Expand Up @@ -289,7 +288,7 @@ def create_html_log(
def run_simulation(
*,
model: language_model.LanguageModel,
players: Sequence[deprecated_agent.BasicAgent | entity_agent.EntityAgent],
players: Sequence[entity_agent_with_logging.EntityAgentWithLogging],
primary_environment: game_master.GameMaster,
clock: game_clock.MultiIntervalClock,
scenes: Sequence[scene_lib.SceneSpec],
Expand Down
14 changes: 9 additions & 5 deletions concordia/factory/environment/factories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@

from absl.testing import absltest
from absl.testing import parameterized
from concordia.agents import deprecated_agent
from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import blank_memories
from concordia.associative_memory import formative_memories
from concordia.associative_memory import importance_function
from concordia.clocks import game_clock
from concordia.components import agent as agent_components
from concordia.environment import game_master
from concordia.factory.environment import basic_game_master
from concordia.language_model import no_language_model
Expand Down Expand Up @@ -63,12 +64,15 @@ def test_give_me_a_name(self, environment_name: str):
importance=importance_model_gm.importance,
clock_now=clock.now,
)
player_a = deprecated_agent.BasicAgent(
act_component = agent_components.concat_act_component.ConcatActComponent(
model=model,
agent_name='Rakshit',
clock=clock,
components=[],
update_interval=datetime.timedelta(hours=1),
component_order=[],
)
player_a = entity_agent_with_logging.EntityAgentWithLogging(
agent_name='Rakshit',
act_component=act_component,
context_components={},
)

players = [player_a]
Expand Down
42 changes: 0 additions & 42 deletions concordia/utils/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@
from collections.abc import Iterable, Sequence
import datetime
import functools
import json

from concordia.agents import entity_agent_with_logging
from concordia.document import interactive_document
from concordia.language_model import language_model
from concordia.typing import component
from concordia.typing import entity_component
from concordia.utils import concurrency


Expand Down Expand Up @@ -147,42 +144,3 @@ def apply_recursively(
getattr(parent_component, function_name)()
else:
getattr(parent_component, function_name)(function_arg)


def save_to_json(
agent: entity_agent_with_logging.EntityAgentWithLogging,
) -> str:
"""Saves an agent to JSON data.
This function saves the agent's state to a JSON string, which can be loaded
afterwards with `rebuild_from_json`. The JSON data
includes the state of the agent's context components, act component, memory,
agent name and the initial config. The clock, model and embedder are not
saved and will have to be provided when the agent is rebuilt. The agent must
be in the `READY` phase to be saved.
Args:
agent: The agent to save.
Returns:
A JSON string representing the agent's state.
Raises:
ValueError: If the agent is not in the READY phase.
"""

if agent.get_phase() != entity_component.Phase.READY:
raise ValueError('The agent must be in the `READY` phase to be saved.')

data = {
component_name: agent.get_component(component_name).get_state()
for component_name in agent.get_all_context_components()
}

data['act_component'] = agent.get_act_component().get_state()

config = agent.get_config()
if config is not None:
data['agent_config'] = config.to_dict()

return json.dumps(data)
59 changes: 59 additions & 0 deletions concordia/utils/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Json helper functions."""

import json

from concordia.agents import entity_agent_with_logging
from concordia.typing import entity_component


def save_to_json(
agent: entity_agent_with_logging.EntityAgentWithLogging,
) -> str:
"""Saves an agent to JSON data.
This function saves the agent's state to a JSON string, which can be loaded
afterwards with `rebuild_from_json`. The JSON data
includes the state of the agent's context components, act component, memory,
agent name and the initial config. The clock, model and embedder are not
saved and will have to be provided when the agent is rebuilt. The agent must
be in the `READY` phase to be saved.
Args:
agent: The agent to save.
Returns:
A JSON string representing the agent's state.
Raises:
ValueError: If the agent is not in the READY phase.
"""

if agent.get_phase() != entity_component.Phase.READY:
raise ValueError('The agent must be in the `READY` phase to be saved.')

data = {
component_name: agent.get_component(component_name).get_state()
for component_name in agent.get_all_context_components()
}

data['act_component'] = agent.get_act_component().get_state()

config = agent.get_config()
if config is not None:
data['agent_config'] = config.to_dict()

return json.dumps(data)

0 comments on commit 3a7cc99

Please sign in to comment.