Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
prasmussen15 committed Nov 24, 2024
1 parent 1ca6dac commit 4304c58
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 24 deletions.
3 changes: 1 addition & 2 deletions examples/podcast/transcript_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import re
from datetime import datetime, timedelta, timezone
from typing import List

from pydantic import BaseModel

Expand Down Expand Up @@ -36,7 +35,7 @@ def parse_timestamp(timestamp: str) -> timedelta:
return timedelta() # Return 0 duration if parsing fails


def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[ParsedMessage]:
def parse_conversation_file(file_path: str, speakers: list[Speaker]) -> list[ParsedMessage]:
with open(file_path) as file:
content = file.read()

Expand Down
3 changes: 1 addition & 2 deletions graphiti_core/cross_encoder/bge_reranker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""

import asyncio
from typing import List, Tuple

from sentence_transformers import CrossEncoder

Expand All @@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
def __init__(self):
self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')

async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
if not passages:
return []

Expand Down
7 changes: 3 additions & 4 deletions graphiti_core/cross_encoder/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""

from abc import ABC, abstractmethod
from typing import List, Tuple


class CrossEncoderClient(ABC):
Expand All @@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
"""

@abstractmethod
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
"""
Rank the given passages based on their relevance to the query.
Args:
query (str): The query string.
passages (List[str]): A list of passages to rank.
passages (list[str]): A list of passages to rank.
Returns:
List[Tuple[str, float]]: A list of tuples containing the passage and its score,
List[tuple[str, float]]: A list of tuples containing the passage and its score,
sorted in descending order of relevance.
"""
pass
4 changes: 2 additions & 2 deletions graphiti_core/embedder/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

from abc import ABC, abstractmethod
from typing import Iterable, List, Literal
from typing import Iterable, Literal

Check failure on line 18 in graphiti_core/embedder/client.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP035)

graphiti_core/embedder/client.py:18:1: UP035 Import from `collections.abc` instead: `Iterable`

from pydantic import BaseModel, Field

Expand All @@ -29,6 +29,6 @@ class EmbedderConfig(BaseModel):
class EmbedderClient(ABC):
@abstractmethod
async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
pass
4 changes: 2 additions & 2 deletions graphiti_core/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""

from typing import Iterable, List
from typing import Iterable

Check failure on line 17 in graphiti_core/embedder/openai.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP035)

graphiti_core/embedder/openai.py:17:1: UP035 Import from `collections.abc` instead: `Iterable`

from openai import AsyncOpenAI
from openai.types import EmbeddingModel
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(self, config: OpenAIEmbedderConfig | None = None):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
result = await self.client.embeddings.create(
input=input_data, model=self.config.embedding_model
Expand Down
6 changes: 3 additions & 3 deletions graphiti_core/embedder/voyage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""

from typing import Iterable, List
from typing import Iterable

Check failure on line 17 in graphiti_core/embedder/voyage.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP035)

graphiti_core/embedder/voyage.py:17:1: UP035 Import from `collections.abc` instead: `Iterable`

import voyageai # type: ignore
from pydantic import Field
Expand All @@ -41,11 +41,11 @@ def __init__(self, config: VoyageAIEmbedderConfig | None = None):
self.client = voyageai.AsyncClient(api_key=config.api_key)

async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
if isinstance(input_data, str):
input_list = [input_data]
elif isinstance(input_data, List):
elif isinstance(input_data, list):
input_list = [str(i) for i in input_data if i]
else:
input_list = [str(i) for i in input_data if i is not None]
Expand Down
13 changes: 6 additions & 7 deletions graphiti_core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import logging
from datetime import datetime, timezone
from time import time
from typing import List

from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
Expand All @@ -34,11 +33,11 @@


def build_episodic_edges(
entity_nodes: List[EntityNode],
entity_nodes: list[EntityNode],
episode: EpisodicNode,
created_at: datetime,
) -> List[EpisodicEdge]:
edges: List[EpisodicEdge] = [
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = [
EpisodicEdge(
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
Expand All @@ -52,11 +51,11 @@ def build_episodic_edges(


def build_community_edges(
entity_nodes: List[EntityNode],
entity_nodes: list[EntityNode],
community_node: CommunityNode,
created_at: datetime,
) -> List[CommunityEdge]:
edges: List[CommunityEdge] = [
) -> list[CommunityEdge]:
edges: list[CommunityEdge] = [
CommunityEdge(
source_node_uuid=community_node.uuid,
target_node_uuid=node.uuid,
Expand Down
3 changes: 1 addition & 2 deletions graphiti_core/utils/maintenance/temporal_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import logging
from datetime import datetime
from time import time
from typing import List

from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMClient
Expand All @@ -31,7 +30,7 @@ async def extract_edge_dates(
llm_client: LLMClient,
edge: EntityEdge,
current_episode: EpisodicNode,
previous_episodes: List[EpisodicNode],
previous_episodes: list[EpisodicNode],
) -> tuple[datetime | None, datetime | None]:
context = {
'edge_fact': edge.fact,
Expand Down

0 comments on commit 4304c58

Please sign in to comment.