Skip to content

Commit

Permalink
Add ability to save and load rational and basic agents to/from json.
Browse files Browse the repository at this point in the history
- add methods for saving and loading to components
- add save / load to all main agent factories
- saving and loading of associative memories

PiperOrigin-RevId: 690960131
Change-Id: Ie323681fe63fbc966d70bbd21bcaf9d09f06c48d
  • Loading branch information
vezhnick authored and copybara-github committed Oct 29, 2024
1 parent c8a45e5 commit eb06a1f
Show file tree
Hide file tree
Showing 21 changed files with 832 additions and 69 deletions.
8 changes: 8 additions & 0 deletions concordia/agents/entity_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def get_component(
component = self._context_components[name]
return cast(entity_component.ComponentT, component)

def get_act_component(self) -> entity_component.ActingComponent:
return self._act_component

def get_all_context_components(
self,
) -> Mapping[str, entity_component.ContextComponent]:
return types.MappingProxyType(self._context_components)

def _parallel_call_(
self,
method_name: str,
Expand Down
8 changes: 8 additions & 0 deletions concordia/agents/entity_agent_with_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
"""A modular entity agent using the new component system with side logging."""

from collections.abc import Mapping
import copy
import types
from typing import Any
from absl import logging
from concordia.agents import entity_agent
from concordia.associative_memory import formative_memories
from concordia.typing import agent
from concordia.typing import entity_component
from concordia.utils import measurements as measurements_lib
Expand All @@ -39,6 +41,7 @@ def __init__(
types.MappingProxyType({})
),
component_logging: measurements_lib.Measurements | None = None,
config: formative_memories.AgentConfig | None = None,
):
"""Initializes the agent.
Expand All @@ -56,6 +59,7 @@ def __init__(
None, a NoOpContextProcessor will be used.
context_components: The ContextComponents that will be used by the agent.
component_logging: The channels where components publish events.
config: The agent configuration, used for checkpointing and debug.
"""
super().__init__(agent_name=agent_name,
act_component=act_component,
Expand All @@ -75,6 +79,7 @@ def __init__(
on_error=lambda e: logging.error('Error in component logging: %s', e))
else:
self._channel_names = []
self._config = copy.deepcopy(config)

def _set_log(self, log: tuple[Any, ...]) -> None:
"""Set the logging object to return from get_last_log.
Expand All @@ -89,3 +94,6 @@ def _set_log(self, log: tuple[Any, ...]) -> None:
def get_last_log(self):
self._tick.on_next(None) # Trigger the logging.
return self._log

def get_config(self) -> formative_memories.AgentConfig | None:
return copy.deepcopy(self._config)
25 changes: 25 additions & 0 deletions concordia/associative_memory/associative_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
import threading

from concordia.associative_memory import importance_function
from concordia.typing import entity_component
import numpy as np
import pandas as pd


_NUM_TO_RETRIEVE_TO_CONTEXTUALIZE_IMPORTANCE = 25


Expand Down Expand Up @@ -79,6 +81,29 @@ def __init__(
self._interval = clock_step_size
self._stored_hashes = set()

def get_state(self) -> entity_component.ComponentState:
"""Converts the AssociativeMemory to a dictionary."""

with self._memory_bank_lock:
output = {
'seed': self._seed,
'stored_hashes': list(self._stored_hashes),
'memory_bank': self._memory_bank.to_json(),
}
if self._interval:
output['interval'] = self._interval.total_seconds()
return output

def set_state(self, state: entity_component.ComponentState) -> None:
"""Sets the AssociativeMemory from a dictionary."""

with self._memory_bank_lock:
self._seed = state['seed']
self._stored_hashes = set(state['stored_hashes'])
self._memory_bank = pd.read_json(state['memory_bank'])
if 'interval' in state:
self._interval = datetime.timedelta(seconds=state['interval'])

def add(
self,
text: str,
Expand Down
19 changes: 17 additions & 2 deletions concordia/associative_memory/formative_memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

"""This is a factory for generating memories for concordia agents."""

from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Collection, Sequence
import dataclasses
import datetime
import logging
Expand Down Expand Up @@ -58,9 +58,24 @@ class AgentConfig:
specific_memories: str = ''
goal: str = ''
date_of_birth: datetime.datetime = DEFAULT_DOB
formative_ages: Iterable[int] = DEFAULT_FORMATIVE_AGES
formative_ages: Collection[int] = DEFAULT_FORMATIVE_AGES
extras: dict[str, Any] = dataclasses.field(default_factory=dict)

def to_dict(self) -> dict[str, Any]:
"""Converts the AgentConfig to a dictionary."""
result = dataclasses.asdict(self)
result['date_of_birth'] = self.date_of_birth.isoformat()
return result

@classmethod
def from_dict(cls, data: dict[str, Any]) -> 'AgentConfig':
"""Initializes an AgentConfig from a dictionary."""
date_of_birth = datetime.datetime.fromisoformat(
data['date_of_birth']
)
data = data | {'date_of_birth': date_of_birth}
return cls(**data)


class FormativeMemoryFactory:
"""Generator of formative memories."""
Expand Down
11 changes: 10 additions & 1 deletion concordia/components/agent/action_spec_ignored.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import abc
import threading
from typing import Final
from typing import Final, Any

from concordia.typing import entity as entity_lib
from concordia.typing import entity_component
from typing_extensions import override


class ActionSpecIgnored(
Expand Down Expand Up @@ -89,3 +90,11 @@ def get_named_component_pre_act_value(self, component_name: str) -> str:
"""Returns the pre-act value of a named component of the parent entity."""
return self.get_entity().get_component(
component_name, type_=ActionSpecIgnored).get_pre_act_value()

@override
def set_state(self, state: entity_component.ComponentState) -> Any:
return None

@override
def get_state(self) -> entity_component.ComponentState:
return {}
5 changes: 2 additions & 3 deletions concordia/components/agent/all_similar_memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Return all memories similar to a prompt and filter them for relevance.
"""
"""Return all memories similar to a prompt and filter them for relevance."""

from collections.abc import Mapping
import types
from typing import Mapping

from concordia.components.agent import action_spec_ignored
from concordia.components.agent import memory_component
Expand Down
8 changes: 8 additions & 0 deletions concordia/components/agent/concat_act_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,11 @@ def _log(self,
'Value': result,
'Prompt': prompt.view().text().splitlines(),
})

def get_state(self) -> entity_component.ComponentState:
"""Converts the component to a dictionary."""
return {}

def set_state(self, state: entity_component.ComponentState) -> None:
pass

8 changes: 8 additions & 0 deletions concordia/components/agent/memory_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def _check_phase(self) -> None:
'You can only access the memory outside of the `UPDATE` phase.'
)

def get_state(self) -> Mapping[str, Any]:
with self._lock:
return self._memory.get_state()

def set_state(self, state: Mapping[str, Any]) -> None:
with self._lock:
self._memory.set_state(state)

def retrieve(
self,
query: str = '',
Expand Down
12 changes: 12 additions & 0 deletions concordia/components/agent/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,15 @@ def _make_pre_act_value(self) -> str:
})

return result

def get_state(self) -> entity_component.ComponentState:
"""Converts the component to JSON data."""
with self._lock:
return {
'current_plan': self._current_plan,
}

def set_state(self, state: entity_component.ComponentState) -> None:
"""Sets the component state from JSON data."""
with self._lock:
self._current_plan = state['current_plan']
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ def pre_act(
def update(self) -> None:
self._component.update()

def get_state(self) -> entity_component.ComponentState:
return self._component.get_state()

def set_state(self, state: entity_component.ComponentState) -> None:
self._component.set_state(state)


class Identity(QuestionOfQueryAssociatedMemories):
"""Identity component containing a few characteristics.
Expand Down
1 change: 1 addition & 0 deletions concordia/components/agent/question_of_recent_memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class AvailableOptionsPerception(QuestionOfRecentMemories):
"""This component answers the question 'what actions are available to me?'."""

def __init__(self, **kwargs):

super().__init__(
question=(
'Given the statements above, what actions are available to '
Expand Down
90 changes: 56 additions & 34 deletions concordia/contrib/components/agent/affect_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def __init__(
num_salient_to_retrieve: retrieve this many salient memories.
num_questions_to_consider: how many questions to ask self.
num_to_retrieve_per_question: how many memories to retrieve per question.
pre_act_key: Prefix to add to the output of the component when called
in `pre_act`.
pre_act_key: Prefix to add to the output of the component when called in
`pre_act`.
logging_channel: The channel to use for debug logging.
"""
super().__init__(pre_act_key)
Expand All @@ -91,35 +91,39 @@ def _make_pre_act_value(self) -> str:
for key, prefix in self._components.items()
])
salience_chain_of_thought = interactive_document.InteractiveDocument(
self._model)
self._model
)

query = f'salient event, period, feeling, or concept for {agent_name}'
timed_query = f'[{self._clock.now()}] {query}'

memory = self.get_entity().get_component(
self._memory_component_name,
type_=memory_component.MemoryComponent)
mem_retrieved = '\n'.join(
[mem.text for mem in memory.retrieve(
self._memory_component_name, type_=memory_component.MemoryComponent
)
mem_retrieved = '\n'.join([
mem.text
for mem in memory.retrieve(
query=timed_query,
scoring_fn=legacy_associative_memory.RetrieveAssociative(
use_recency=True, add_time=True
),
limit=self._num_salient_to_retrieve)]
)
limit=self._num_salient_to_retrieve,
)
])

question_list = []

questions = salience_chain_of_thought.open_question(
(
f'Recent feelings: {self._previous_pre_act_value} \n' +
f"{agent_name}'s relevant memory:\n" +
f'{mem_retrieved}\n' +
f'Current time: {self._clock.now()}\n' +
'\nGiven the thoughts and beliefs above, what are the ' +
f'{self._num_questions_to_consider} most salient high-level '+
f'questions that can be answered about what {agent_name} ' +
'might be feeling about the current moment?'),
f'Recent feelings: {self._previous_pre_act_value} \n'
+ f"{agent_name}'s relevant memory:\n"
+ f'{mem_retrieved}\n'
+ f'Current time: {self._clock.now()}\n'
+ '\nGiven the thoughts and beliefs above, what are the '
+ f'{self._num_questions_to_consider} most salient high-level '
+ f'questions that can be answered about what {agent_name} '
+ 'might be feeling about the current moment?'
),
answer_prefix='- ',
max_tokens=3000,
terminators=(),
Expand All @@ -128,25 +132,33 @@ def _make_pre_act_value(self) -> str:
question_related_mems = []
for question in questions:
question_list.append(question)
question_related_mems = [mem.text for mem in memory.retrieve(
query=agent_name,
scoring_fn=legacy_associative_memory.RetrieveAssociative(
use_recency=False, add_time=True
),
limit=self._num_to_retrieve_per_question)]
question_related_mems = [
mem.text
for mem in memory.retrieve(
query=agent_name,
scoring_fn=legacy_associative_memory.RetrieveAssociative(
use_recency=False, add_time=True
),
limit=self._num_to_retrieve_per_question,
)
]
insights = []
question_related_mems = '\n'.join(question_related_mems)

chain_of_thought = interactive_document.InteractiveDocument(self._model)
insight = chain_of_thought.open_question(
f'Selected memories:\n{question_related_mems}\n' +
f'Recent feelings: {self._previous_pre_act_value} \n\n' +
'New context:\n' + context + '\n' +
f'Current time: {self._clock.now()}\n' +
'What high-level insight can be inferred from the above ' +
f'statements about what {agent_name} might be feeling ' +
'in the current moment?',
max_tokens=2000, terminators=(),)
f'Selected memories:\n{question_related_mems}\n'
+ f'Recent feelings: {self._previous_pre_act_value} \n\n'
+ 'New context:\n'
+ context
+ '\n'
+ f'Current time: {self._clock.now()}\n'
+ 'What high-level insight can be inferred from the above '
+ f'statements about what {agent_name} might be feeling '
+ 'in the current moment?',
max_tokens=2000,
terminators=(),
)
insights.append(insight)

result = '\n'.join(insights)
Expand All @@ -157,9 +169,19 @@ def _make_pre_act_value(self) -> str:
'Key': self.get_pre_act_key(),
'Value': result,
'Salience chain of thought': (
salience_chain_of_thought.view().text().splitlines()),
'Chain of thought': (
chain_of_thought.view().text().splitlines()),
salience_chain_of_thought.view().text().splitlines()
),
'Chain of thought': chain_of_thought.view().text().splitlines(),
})

return result

def get_state(self) -> entity_component.ComponentState:
"""Converts the component to a dictionary."""
return {
'previous_pre_act_value': self._previous_pre_act_value,
}

def set_state(self, state: entity_component.ComponentState) -> None:

self._previous_pre_act_value = str(state['previous_pre_act_value'])
Loading

0 comments on commit eb06a1f

Please sign in to comment.