Skip to content

Commit

Permalink
Fix time serialization in associative_memory
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696147310
Change-Id: Ia08e7e64b18d2ac188100286de7128f9611ec73e
  • Loading branch information
jzleibo authored and copybara-github committed Nov 13, 2024
1 parent 62ad3d0 commit 56807c9
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions concordia/associative_memory/associative_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,14 @@ def get_state(self) -> entity_component.ComponentState:
"""Converts the AssociativeMemory to a dictionary."""

with self._memory_bank_lock:
serialized_times = self._memory_bank['time'].apply(
lambda x: x.strftime('[%d-%b-%Y-%H:%M:%S]')
).tolist()
output = {
'seed': self._seed,
'stored_hashes': list(self._stored_hashes),
'memory_bank': self._memory_bank.to_json(),
'time': serialized_times,
}
if self._interval:
output['interval'] = self._interval.total_seconds()
Expand All @@ -101,6 +105,10 @@ def set_state(self, state: entity_component.ComponentState) -> None:
self._seed = state['seed']
self._stored_hashes = set(state['stored_hashes'])
self._memory_bank = pd.read_json(state['memory_bank'])
self._memory_bank['time'] = [
datetime.datetime.strptime(t, '[%d-%b-%Y-%H:%M:%S]')
for t in state['time']
]
if 'interval' in state:
self._interval = datetime.timedelta(seconds=state['interval'])

Expand Down

0 comments on commit 56807c9

Please sign in to comment.