Skip to content

Commit

Permalink
Add observe_and_summarize supporting agent
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700005751
Change-Id: I38008ecc2a08f328547d0456740745ef46fa0093
  • Loading branch information
jzleibo authored and copybara-github committed Nov 25, 2024
1 parent 7c23be2 commit ba84894
Showing 1 changed file with 141 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An Agent Factory."""

import datetime

from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components import agent as agent_components
from concordia.contrib.components.agent import observations_since_last_update
from concordia.contrib.components.agent import situation_representation_via_narrative
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.utils import measurements as measurements_lib


def _get_class_name(object_: object) -> str:
return object_.__class__.__name__


def build_agent(
*,
config: formative_memories.AgentConfig,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
clock: game_clock.MultiIntervalClock,
update_time_interval: datetime.timedelta | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
"""Build an agent.
Args:
config: The agent config to use.
model: The language model to use.
memory: The agent's memory object.
clock: The clock to use.
update_time_interval: Agent calls update every time this interval passes.
Returns:
An agent.
"""
del update_time_interval
if config.extras.get('main_character', False):
raise ValueError('This function is meant for a supporting character '
'but it was called on a main character.')

agent_name = config.name

raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)

measurements = measurements_lib.Measurements()
instructions = agent_components.instructions.Instructions(
agent_name=agent_name,
logging_channel=measurements.get_channel('Instructions').on_next,
)

time_display = agent_components.report_function.ReportFunction(
function=clock.current_time_interval_str,
pre_act_key='\nCurrent time',
logging_channel=measurements.get_channel('TimeDisplay').on_next,
)

observation_label = '\nObservation'
observation = observations_since_last_update.ObservationsSinceLastUpdate(
model=model,
clock_now=clock.now,
pre_act_key=observation_label,
logging_channel=measurements.get_channel(
'ObservationsSinceLastUpdate').on_next,
)

situation_representation_label = (
f'\nQuestion: What situation is {agent_name} in right now?\nAnswer')
situation_representation = (
situation_representation_via_narrative.SituationRepresentation(
model=model,
clock_now=clock.now,
pre_act_key=situation_representation_label,
logging_channel=measurements.get_channel(
'SituationRepresentation'
).on_next,
)
)

if config.goal:
goal_label = '\nOverarching goal'
overarching_goal = agent_components.constant.Constant(
state=config.goal,
pre_act_key=goal_label,
logging_channel=measurements.get_channel(goal_label).on_next)
else:
goal_label = None
overarching_goal = None

entity_components = (
# Components that provide pre_act context.
instructions,
time_display,
situation_representation,
observation,
)
components_of_agent = {_get_class_name(component): component
for component in entity_components}
components_of_agent[
agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
agent_components.memory_component.MemoryComponent(raw_memory))

component_order = list(components_of_agent.keys())
if overarching_goal is not None:
components_of_agent[goal_label] = overarching_goal
# Place goal after the instructions.
component_order.insert(1, goal_label)

act_component = agent_components.concat_act_component.ConcatActComponent(
model=model,
clock=clock,
component_order=component_order,
logging_channel=measurements.get_channel('ActComponent').on_next,
)

agent = entity_agent_with_logging.EntityAgentWithLogging(
agent_name=agent_name,
act_component=act_component,
context_components=components_of_agent,
component_logging=measurements,
)

return agent

0 comments on commit ba84894

Please sign in to comment.