diff --git a/graphiti_core/llm_client/openai_client.py b/graphiti_core/llm_client/openai_client.py
index c92b4fb3..d8b02e8b 100644
--- a/graphiti_core/llm_client/openai_client.py
+++ b/graphiti_core/llm_client/openai_client.py
@@ -16,6 +16,7 @@
 import logging
 import typing
+from typing import ClassVar
 
 import openai
 from openai import AsyncOpenAI
 
@@ -53,6 +54,9 @@ class OpenAIClient(LLMClient):
     Generates a response from the language model based on the provided messages.
     """
 
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
     def __init__(
         self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
     ):
@@ -104,7 +108,7 @@
             elif response_object.refusal:
                 raise RefusalError(response_object.refusal)
             else:
-                raise Exception('No response from LLM')
+                raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
         except openai.LengthFinishReasonError as e:
             raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
         except openai.RateLimitError as e:
@@ -116,6 +120,43 @@
     async def generate_response(
         self, messages: list[Message], response_model: type[BaseModel] | None = None
     ) -> dict[str, typing.Any]:
-        response = await self._generate_response(messages, response_model)
-
-        return response
+        retry_count = 0
+        last_error = None
+
+        while retry_count <= self.MAX_RETRIES:
+            try:
+                response = await self._generate_response(messages, response_model)
+                return response
+            except (RateLimitError, RefusalError):
+                # These errors should not trigger retries
+                raise
+            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
+                # Let OpenAI's client handle these retries
+                raise
+            except Exception as e:
+                last_error = e
+
+                # Don't retry if we've hit the max retries
+                if retry_count >= self.MAX_RETRIES:
+                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                    raise
+
+                retry_count += 1
+
+                # Construct a detailed error message for the LLM
+                error_context = (
+                    f'The previous response attempt was invalid. '
+                    f'Error type: {e.__class__.__name__}. '
+                    f'Error details: {str(e)}. '
+                    f'Please try again with a valid response, ensuring the output matches '
+                    f'the expected format and constraints.'
+                )
+
+                error_message = Message(role='user', content=error_context)
+                messages.append(error_message)
+                logger.warning(
+                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                )
+
+        # If we somehow get here, raise the last error
+        raise last_error or Exception('Max retries exceeded with no specific error')