Skip to content

Commit

Permalink
search scope and summary search (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielchalef authored Oct 31, 2023
1 parent bd9070b commit ade1072
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 9 deletions.
22 changes: 20 additions & 2 deletions examples/memory_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def main() -> None:
)

# Naive wait for memory to be enriched and indexed
time.sleep(2.0)
time.sleep(5.0)

# Get Memory for session
print(f"\n3---getMemory for Session: {session_id}")
Expand All @@ -128,7 +128,7 @@ async def main() -> None:
search_payload = MemorySearchPayload(
text=query,
metadata={
"where": {"jsonpath": '$.system.entities[*] ? (@.Label == "LOC")'}
"where": {"jsonpath": '$.system.entities[*] ? (@.Label == "GPE")'}
},
)
print(f"\n4---searchMemory for Query: '{query}'")
Expand Down Expand Up @@ -159,6 +159,24 @@ async def main() -> None:
except NotFoundError:
print(f"Nothing found for Session {session_id}")

# Search Summary with MMR reranking
search_payload = MemorySearchPayload(
text=query,
search_scope="summary",
search_type="mmr",
mmr_lambda=0.5,
)
print(f"\n4---searchMemory for MMR Query: '{query}'")
try:
search_results = client.memory.search_memory(
session_id, search_payload, limit=3
)
for search_result in search_results:
message_content = search_result.summary
print(f"Search result: {message_content}")
except NotFoundError:
print("Nothing found for Session" + session_id)

# Delete Memory for session
print(f"\n5---deleteMemory for Session: {session_id}")
try:
Expand Down
20 changes: 19 additions & 1 deletion examples/memory_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def main() -> None:
print(f"Unable to add memory to session {session_id}. Error: {e}")

# Naive wait for memory to be enriched and indexed
time.sleep(2.0)
time.sleep(5.0)

# Get memory we just added
print(f"\n3---getMemory for Session: {session_id}")
Expand Down Expand Up @@ -143,6 +143,24 @@ def main() -> None:
except NotFoundError:
print("Nothing found for Session" + session_id)

# Search Summary with MMR reranking
search_payload = MemorySearchPayload(
text=query,
search_scope="summary",
search_type="mmr",
mmr_lambda=0.5,
)
print(f"\n4---searchMemory for MMR Query: '{query}'")
try:
search_results = client.memory.search_memory(
session_id, search_payload, limit=3
)
for search_result in search_results:
message_content = search_result.summary
print(f"Search result: {message_content}")
except NotFoundError:
print("Nothing found for Session" + session_id)

# Delete memory
print(f"Deleting memory for Session: {session_id}")
try:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zep-python"
version = "1.3.0"
version = "1.4.0"
description = "Zep: Fast, scalable building blocks for LLM apps. This is the Python client for the Zep service."
authors = ["Daniel Chalef <[email protected]>"]
readme = "README.md"
Expand Down
46 changes: 46 additions & 0 deletions tests/memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Memory,
MemorySearchPayload,
Message,
SearchScope,
Session,
)
from zep_python.utils import SearchType
Expand Down Expand Up @@ -264,6 +265,51 @@ async def test_asearch_memory_invalid_search_type(httpx_mock: HTTPXMock):
_ = await client.memory.asearch_memory(session_id, search_payload)


@pytest.mark.asyncio
async def test_asearch_memory_scope_summary(httpx_mock: HTTPXMock):
session_id = str(uuid4())

search_payload = MemorySearchPayload(
text="Test query",
metadata={"where": {"jsonpath": '$.system.entities[*] ? (@.Label == "DATE")'}},
search_scope=SearchScope.summary,
mmr_lambda=0.5,
)
mock_response = [
{
"summary": {
"uuid": "msg-uuid",
"content": "Test summary",
},
"dist": 0.9,
}
]

httpx_mock.add_response(status_code=200, json=mock_response)

async with ZepClient(base_url=API_BASE_URL) as client:
search_results = await client.memory.asearch_memory(session_id, search_payload)

assert len(search_results) == 1
assert search_results[0].summary.uuid == "msg-uuid"
assert search_results[0].summary.content == "Test summary"
assert search_results[0].dist == 0.9


@pytest.mark.asyncio
async def test_asearch_memory_invalid_search_scope(httpx_mock: HTTPXMock):
session_id = str(uuid4())

search_payload = MemorySearchPayload(
text="Test query",
search_scope="invalid",
)

with ZepClient(base_url=API_BASE_URL) as client:
with pytest.raises(ValueError):
_ = await client.memory.asearch_memory(session_id, search_payload)


@pytest.mark.asyncio
async def test_asearch_memory_no_payload(httpx_mock: HTTPXMock):
session_id = str(uuid4())
Expand Down
7 changes: 7 additions & 0 deletions zep_python/memory/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
MemorySearchPayload,
MemorySearchResult,
Message,
SearchScope,
Session,
Summary,
)
Expand Down Expand Up @@ -719,6 +720,9 @@ def search_memory(
if search_payload.search_type not in SearchType.__members__:
raise ValueError("search_type must be one of 'similarity' or 'mmr'")

if search_payload.search_scope not in SearchScope.__members__:
raise ValueError("search_scope must be one of 'messages' or 'summary'")

params = {"limit": limit} if limit is not None else {}
response = self.client.post(
f"/sessions/{session_id}/search",
Expand Down Expand Up @@ -770,6 +774,9 @@ async def asearch_memory(
if search_payload.search_type not in SearchType.__members__:
raise ValueError("search_type must be one of 'similarity' or 'mmr'")

if search_payload.search_scope not in SearchScope.__members__:
raise ValueError("search_scope must be one of 'messages' or 'summary'")

params = {"limit": limit} if limit is not None else {}
response = await self.aclient.post(
f"/sessions/{session_id}/search",
Expand Down
19 changes: 15 additions & 4 deletions zep_python/memory/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
Expand All @@ -11,6 +12,11 @@
from pydantic import BaseModel, Field


class SearchScope(str, Enum):
messages = "messages"
summary = "summary"


class Session(BaseModel):
"""
Represents a session object with a unique identifier, metadata,
Expand Down Expand Up @@ -177,6 +183,9 @@ class MemorySearchPayload(BaseModel):
Metadata associated with the search query.
text : str
The text of the search query.
search_scope : Optional[str]
Search over messages or summaries. Defaults to "messages".
Must be one of "messages" or "summary".
search_type : Optional[str]
The type of search to perform. Defaults to "similarity".
Must be one of "similarity" or "mmr".
Expand All @@ -186,6 +195,7 @@ class MemorySearchPayload(BaseModel):

text: Optional[str] = Field(default=None)
metadata: Optional[Dict[str, Any]] = Field(default=None)
search_scope: Optional[str] = Field(default="messages")
search_type: Optional[str] = Field(default="similarity")
mmr_lambda: Optional[float] = Field(default=None)

Expand All @@ -197,16 +207,17 @@ class MemorySearchResult(BaseModel):
Attributes
----------
message : Optional[Dict[str, Any]]
The message associated with the search result.
The message matched by search.
summary : Optional[Summary]
The summary matched by search.
metadata : Optional[Dict[str, Any]]
Metadata associated with the search result.
summary : Optional[str]
The summary of the search result.
dist : Optional[float]
The distance metric of the search result.
"""

# TODO: Legacy bug. message should be a Message object.
message: Optional[Dict[str, Any]] = None
summary: Optional[Summary] = None
metadata: Optional[Dict[str, Any]] = None
summary: Optional[str] = None
dist: Optional[float] = None
2 changes: 1 addition & 1 deletion zep_python/zep_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
API_BASE_PATH = "/api/v1"
API_TIMEOUT = 10

MINIMUM_SERVER_VERSION = "0.16.0"
MINIMUM_SERVER_VERSION = "0.17.0"


class ZepClient:
Expand Down

0 comments on commit ade1072

Please sign in to comment.