diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py
index 22cc3795..f2a19208 100644
--- a/graphiti_core/llm_client/client.py
+++ b/graphiti_core/llm_client/client.py
@@ -67,6 +67,28 @@ def __init__(self, config: LLMConfig | None, cache: bool = False):
         else None,
         reraise=True,
     )
+    def _clean_input(self, input: str) -> str:
+        """Clean input string of invalid unicode and control characters.
+
+        Args:
+            input: Raw input string to be cleaned
+
+        Returns:
+            Cleaned string safe for LLM processing
+        """
+        # Clean any invalid Unicode
+        cleaned = input.encode('utf-8', errors='ignore').decode('utf-8')
+
+        # Remove zero-width characters and other invisible unicode
+        zero_width = '\u200b\u200c\u200d\ufeff\u2060'
+        for char in zero_width:
+            cleaned = cleaned.replace(char, '')
+
+        # Remove control characters except newlines, returns, and tabs
+        cleaned = ''.join(char for char in cleaned if ord(char) >= 32 or char in '\n\r\t')
+
+        return cleaned
+
     async def _generate_response_with_retry(
         self, messages: list[Message], response_model: type[BaseModel] | None = None
     ) -> dict[str, typing.Any]:
@@ -106,6 +128,9 @@ async def generate_response(
             logger.debug(f'Cache hit for {cache_key}')
             return cached_response
 
+        for message in messages:
+            message.content = self._clean_input(message.content)
+
         response = await self._generate_response_with_retry(messages, response_model)
 
         if self.cache_enabled:
diff --git a/graphiti_core/llm_client/openai_client.py b/graphiti_core/llm_client/openai_client.py
index d8b02e8b..7804e06f 100644
--- a/graphiti_core/llm_client/openai_client.py
+++ b/graphiti_core/llm_client/openai_client.py
@@ -88,6 +88,7 @@ async def _generate_response(
     ) -> dict[str, typing.Any]:
         openai_messages: list[ChatCompletionMessageParam] = []
         for m in messages:
+            m.content = self._clean_input(m.content)
             if m.role == 'user':
                 openai_messages.append({'role': 'user', 'content': m.content})
             elif m.role == 'system':
diff --git a/tests/llm_client/test_client.py b/tests/llm_client/test_client.py
new file mode 100644
index 00000000..4a2fbd7c
--- /dev/null
+++ b/tests/llm_client/test_client.py
@@ -0,0 +1,41 @@
+from graphiti_core.llm_client.client import LLMClient
+from graphiti_core.llm_client.config import LLMConfig
+
+
+class TestLLMClient(LLMClient):
+    """Concrete implementation of LLMClient for testing"""
+
+    async def _generate_response(self, messages, response_model=None):
+        return {'content': 'test'}
+
+
+def test_clean_input():
+    client = TestLLMClient(LLMConfig())
+
+    test_cases = [
+        # Basic text should remain unchanged
+        ('Hello World', 'Hello World'),
+        # Control characters should be removed
+        ('Hello\x00World', 'HelloWorld'),
+        # Newlines, tabs, returns should be preserved
+        ('Hello\nWorld\tTest\r', 'Hello\nWorld\tTest\r'),
+        # Invalid Unicode should be removed
+        ('Hello\udcdeWorld', 'HelloWorld'),
+        # Zero-width characters should be removed
+        ('Hello\u200bWorld', 'HelloWorld'),
+        ('Test\ufeffWord', 'TestWord'),
+        # Multiple issues combined
+        ('Hello\x00\u200b\nWorld\udcde', 'Hello\nWorld'),
+        # Empty string should remain empty
+        ('', ''),
+        # Form feed and other control characters from the error case
+        ('{"edges":[{"relation_typ...\f\x04Hn\\?"}]}', '{"edges":[{"relation_typ...Hn\\?"}]}'),
+        # More specific control character tests
+        ('Hello\x0cWorld', 'HelloWorld'),  # form feed \f
+        ('Hello\x04World', 'HelloWorld'),  # end of transmission
+        # Combined JSON-like string with control characters
+        ('{"test": "value\f\x00\x04"}', '{"test": "value"}'),
+    ]
+
+    for input_str, expected in test_cases:
+        assert client._clean_input(input_str) == expected, f'Failed for input: {repr(input_str)}'
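
For reference, the same cleaning logic as a self-contained sketch that can be run without installing graphiti; the top-level `clean_input` name and the `__main__` demo are illustrative, not part of the library API:

# Standalone sketch of the sanitization added in this patch; mirrors
# LLMClient._clean_input but is not the graphiti API itself.

def clean_input(raw: str) -> str:
    # Round-trip through UTF-8 with errors='ignore' to drop invalid code
    # points such as lone surrogates (e.g. '\udcde').
    cleaned = raw.encode('utf-8', errors='ignore').decode('utf-8')

    # Strip zero-width and other invisible characters.
    for char in '\u200b\u200c\u200d\ufeff\u2060':
        cleaned = cleaned.replace(char, '')

    # Keep printable characters plus newline, carriage return, and tab.
    return ''.join(c for c in cleaned if ord(c) >= 32 or c in '\n\r\t')


if __name__ == '__main__':
    # Mirrors the "multiple issues combined" test case above.
    assert clean_input('Hello\x00\u200b\nWorld\udcde') == 'Hello\nWorld'
    print('ok')

Note that the UTF-8 round-trip is what handles the lone-surrogate case: ord('\udcde') is well above 32, so the final character filter alone would pass it through to the LLM provider.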