Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix temporal invalidation unit tests #23

Merged
merged 15 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/utils/maintenance/graph_data_operations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import logging
from datetime import datetime, timezone
from typing import LiteralString

from neo4j import AsyncDriver
from typing_extensions import LiteralString
paul-paliychuk marked this conversation as resolved.
Show resolved Hide resolved

from core.nodes import EpisodicNode

Expand Down
79 changes: 74 additions & 5 deletions tests/utils/maintenance/test_temporal_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from core.edges import EntityEdge
from core.nodes import EntityNode
from core.nodes import EntityNode, EpisodicNode
from core.utils.maintenance.temporal_operations import (
prepare_edges_for_invalidation,
prepare_invalidation_context,
Expand Down Expand Up @@ -114,7 +114,6 @@ def test_prepare_edges_for_invalidation_missing_nodes():


def test_prepare_invalidation_context():
# Create test data
now = datetime.now()

# Create nodes
Expand Down Expand Up @@ -148,15 +147,49 @@ def test_prepare_invalidation_context():
existing_edges = [existing_edge]
new_edges = [new_edge]

# Create a current episode and previous episodes
current_episode = EpisodicNode(
name='Current Episode',
content='This is the current episode content.',
created_at=now,
valid_at=now,
source='test',
source_description='Test episode for unit testing',
)
previous_episodes = [
EpisodicNode(
name='Previous Episode 1',
content='This is the content of previous episode 1.',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source_description='Test previous episode 1 for unit testing',
),
EpisodicNode(
name='Previous Episode 2',
content='This is the content of previous episode 2.',
created_at=now - timedelta(days=2),
valid_at=now - timedelta(days=2),
source='test',
source_description='Test previous episode 2 for unit testing',
),
]

# Call the function
result = prepare_invalidation_context(existing_edges, new_edges)
result = prepare_invalidation_context(
existing_edges, new_edges, current_episode, previous_episodes
)

# Assert the result
assert isinstance(result, dict)
assert 'existing_edges' in result
assert 'new_edges' in result
assert 'current_episode' in result
assert 'previous_episodes' in result
assert len(result['existing_edges']) == 1
assert len(result['new_edges']) == 1
assert result['current_episode'] == current_episode.content
assert len(result['previous_episodes']) == 2

# Check the format of the existing edge
existing_edge_str = result['existing_edges'][0]
Expand All @@ -176,12 +209,25 @@ def test_prepare_invalidation_context():


def test_prepare_invalidation_context_empty_input():
result = prepare_invalidation_context([], [])
now = datetime.now()
current_episode = EpisodicNode(
name='Current Episode',
content='Empty episode',
created_at=now,
valid_at=now,
source='test',
source_description='Test empty episode for unit testing',
)
result = prepare_invalidation_context([], [], current_episode, [])
assert isinstance(result, dict)
assert 'existing_edges' in result
assert 'new_edges' in result
assert 'current_episode' in result
assert 'previous_episodes' in result
assert len(result['existing_edges']) == 0
assert len(result['new_edges']) == 0
assert result['current_episode'] == current_episode.content
assert len(result['previous_episodes']) == 0


def test_prepare_invalidation_context_sorting():
Expand Down Expand Up @@ -215,13 +261,36 @@ def test_prepare_invalidation_context_sorting():
# Prepare test input
existing_edges = [edge_with_nodes1, edge_with_nodes2]

# Create a current episode and previous episodes
current_episode = EpisodicNode(
name='Current Episode',
content='This is the current episode content.',
created_at=now,
valid_at=now,
source='test',
source_description='Test episode for unit testing',
)
previous_episodes = [
EpisodicNode(
name='Previous Episode',
content='This is the content of a previous episode.',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source_description='Test previous episode for unit testing',
),
]

# Call the function
result = prepare_invalidation_context(existing_edges, [])
result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)

# Assert the result
assert len(result['existing_edges']) == 2
assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first
assert edge1.uuid in result['existing_edges'][1] # The older edge should be second
assert result['current_episode'] == current_episode.content
assert len(result['previous_episodes']) == 1
assert result['previous_episodes'][0] == previous_episodes[0].content


# Run the tests
Expand Down
39 changes: 32 additions & 7 deletions tests/utils/maintenance/test_temporal_operations_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from core.edges import EntityEdge
from core.llm_client import LLMConfig, OpenAIClient
from core.nodes import EntityNode
from core.nodes import EntityNode, EpisodicNode
from core.utils.maintenance.temporal_operations import (
invalidate_edges,
)
Expand All @@ -24,7 +24,6 @@ def setup_llm_client():
)


# Helper function to create test data
def create_test_data():
now = datetime.now()

Expand Down Expand Up @@ -53,15 +52,39 @@ def create_test_data():
existing_edge = (node1, edge1, node2)
new_edge = (node1, edge2, node2)

return existing_edge, new_edge
# Create current episode
current_episode = EpisodicNode(
name='Current Episode',
content='Alice now dislikes Bob',
created_at=now,
valid_at=now,
source='test',
source_description='Test episode for unit testing',
)

# Create previous episodes
previous_episodes = [
EpisodicNode(
name='Previous Episode',
content='Alice liked Bob',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source_description='Test previous episode for unit testing',
)
]

return existing_edge, new_edge, current_episode, previous_episodes


@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges():
existing_edge, new_edge = create_test_data()
existing_edge, new_edge, current_episode, previous_episodes = create_test_data()

invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [new_edge])
invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes
)

assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == existing_edge[1].uuid
Expand All @@ -71,9 +94,11 @@ async def test_invalidate_edges():
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_no_invalidation():
existing_edge, _ = create_test_data()
existing_edge, _, current_episode, previous_episodes = create_test_data()

invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [])
invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge], [], current_episode, previous_episodes
)

assert len(invalidated_edges) == 0

Expand Down
Loading