From 7d7bf6f7cc663ee65a5b900edbdca3e38af31118 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 28 Nov 2024 11:16:07 +0100 Subject: [PATCH] refactor: update components to access `ChatMessage.text` instead of `content` (#8589) * introduce text property and deprecate content * release note * use chatmessage.text * release note * linting --- haystack/components/builders/answer_builder.py | 7 ++++++- .../components/builders/chat_prompt_builder.py | 9 ++++++--- .../components/connectors/openapi_service.py | 8 +++++--- haystack/components/generators/openai.py | 5 +---- haystack/components/generators/openai_utils.py | 5 ++++- haystack/components/validators/json_schema.py | 17 +++++++---------- .../use-chatmessage-text-266c94d742c76d32.yaml | 5 +++++ .../builders/test_chat_prompt_builder.py | 8 ++++---- .../connectors/test_openapi_service.py | 4 ++-- test/components/generators/chat/test_azure.py | 2 +- .../generators/chat/test_hugging_face_local.py | 4 ++-- test/components/generators/chat/test_openai.py | 10 +++++----- test/components/validators/test_json_schema.py | 4 ++-- test/core/pipeline/features/test_run.py | 2 +- test/dataclasses/test_chat_message.py | 5 ++--- 15 files changed, 53 insertions(+), 42 deletions(-) create mode 100644 releasenotes/notes/use-chatmessage-text-266c94d742c76d32.yaml diff --git a/haystack/components/builders/answer_builder.py b/haystack/components/builders/answer_builder.py index b89d506aec..9a465cafc3 100644 --- a/haystack/components/builders/answer_builder.py +++ b/haystack/components/builders/answer_builder.py @@ -113,7 +113,12 @@ def run( # pylint: disable=too-many-positional-arguments all_answers = [] for reply, given_metadata in zip(replies, meta): # Extract content from ChatMessage objects if reply is a ChatMessages, else use the string as is - extracted_reply = reply.content if isinstance(reply, ChatMessage) else str(reply) + if isinstance(reply, ChatMessage): + if reply.text is None: + raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {reply}") + extracted_reply = reply.text + else: + extracted_reply = str(reply) extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else {} extracted_metadata = {**extracted_metadata, **given_metadata} diff --git a/haystack/components/builders/chat_prompt_builder.py b/haystack/components/builders/chat_prompt_builder.py index d3be3f8059..fd9969f5b7 100644 --- a/haystack/components/builders/chat_prompt_builder.py +++ b/haystack/components/builders/chat_prompt_builder.py @@ -129,7 +129,9 @@ def __init__( for message in template: if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM): # infer variables from template - ast = self._env.parse(message.content) + if message.text is None: + raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}") + ast = self._env.parse(message.text) template_variables = meta.find_undeclared_variables(ast) variables += list(template_variables) self.variables = variables @@ -192,8 +194,9 @@ def run( for message in template: if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM): self._validate_variables(set(template_variables_combined.keys())) - - compiled_template = self._env.from_string(message.content) + if message.text is None: + raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}") + compiled_template = self._env.from_string(message.text) rendered_content = compiled_template.render(template_variables_combined) # deep copy the message to avoid modifying the original message rendered_message: ChatMessage = deepcopy(message) diff --git a/haystack/components/connectors/openapi_service.py b/haystack/components/connectors/openapi_service.py index 716a50124d..ea89dde54d 100644 --- a/haystack/components/connectors/openapi_service.py +++ b/haystack/components/connectors/openapi_service.py @@ -161,15 +161,17 @@ def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]: :raises ValueError: If the content is not valid JSON or lacks required fields. """ function_payloads = [] + if message.text is None: + raise ValueError(f"The provided ChatMessage has no text.\nChatMessage: {message}") try: - tool_calls = json.loads(message.content) + tool_calls = json.loads(message.text) except json.JSONDecodeError: - raise ValueError("Invalid JSON content, expected OpenAI tools message.", message.content) + raise ValueError("Invalid JSON content, expected OpenAI tools message.", message.text) for tool_call in tool_calls: # this should never happen, but just in case do a sanity check if "type" not in tool_call: - raise ValueError("Message payload doesn't seem to be a tool invocation descriptor", message.content) + raise ValueError("Message payload doesn't seem to be a tool invocation descriptor", message.text) # In OpenAPIServiceConnector we know how to handle functions tools only if tool_call["type"] == "function": diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index f296ec82ca..d50b082556 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -240,10 +240,7 @@ def run( for response in completions: self._check_finish_reason(response) - return { - "replies": [message.content for message in completions], - "meta": [message.meta for message in completions], - } + return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]} @staticmethod def _create_message_from_chunks( diff --git a/haystack/components/generators/openai_utils.py b/haystack/components/generators/openai_utils.py index 555e0ee30b..5b1838c386 100644 --- a/haystack/components/generators/openai_utils.py +++ b/haystack/components/generators/openai_utils.py @@ -18,7 +18,10 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]: - `content` - `name` (optional) """ - openai_msg = {"role": message.role.value, "content": message.content} + if message.text is None: + raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}") + + openai_msg = {"role": message.role.value, "content": message.text} if message.name: openai_msg["name"] = message.name diff --git a/haystack/components/validators/json_schema.py b/haystack/components/validators/json_schema.py index 13f35be87e..0a449aff42 100644 --- a/haystack/components/validators/json_schema.py +++ b/haystack/components/validators/json_schema.py @@ -141,18 +141,20 @@ def run( dictionaries. """ last_message = messages[-1] - if not is_valid_json(last_message.content): + if last_message.text is None: + raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {last_message}") + if not is_valid_json(last_message.text): return { "validation_error": [ ChatMessage.from_user( - f"The message '{last_message.content}' is not a valid JSON object. " + f"The message '{last_message.text}' is not a valid JSON object. " f"Please provide only a valid JSON object in string format." f"Don't use any markdown and don't add any comment." ) ] } - last_message_content = json.loads(last_message.content) + last_message_content = json.loads(last_message.text) json_schema = json_schema or self.json_schema error_template = error_template or self.error_template or self.default_error_template @@ -182,16 +184,11 @@ def run( error_template = error_template or self.default_error_template recovery_prompt = self._construct_error_recovery_message( - error_template, - str(e), - error_path, - error_schema_path, - validation_schema, - failing_json=last_message.content, + error_template, str(e), error_path, error_schema_path, validation_schema, failing_json=last_message.text ) return {"validation_error": [ChatMessage.from_user(recovery_prompt)]} - def _construct_error_recovery_message( + def _construct_error_recovery_message( # pylint: disable=too-many-positional-arguments self, error_template: str, error_message: str, diff --git a/releasenotes/notes/use-chatmessage-text-266c94d742c76d32.yaml b/releasenotes/notes/use-chatmessage-text-266c94d742c76d32.yaml new file mode 100644 index 0000000000..7c4062fc70 --- /dev/null +++ b/releasenotes/notes/use-chatmessage-text-266c94d742c76d32.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Replace usage of `ChatMessage.content` with `ChatMessage.text` across the codebase. + This is done in preparation for the removal of `content` in Haystack 2.9.0. diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index 7e42eedbcb..5e1ae6132e 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -18,8 +18,8 @@ def test_init(self): ] ) assert builder.required_variables == [] - assert builder.template[0].content == "This is a {{ variable }}" - assert builder.template[1].content == "This is a {{ variable2 }}" + assert builder.template[0].text == "This is a {{ variable }}" + assert builder.template[1].text == "This is a {{ variable2 }}" assert builder._variables is None assert builder._required_variables is None @@ -62,7 +62,7 @@ def test_init_with_required_variables(self): template=[ChatMessage.from_user("This is a {{ variable }}")], required_variables=["variable"] ) assert builder.required_variables == ["variable"] - assert builder.template[0].content == "This is a {{ variable }}" + assert builder.template[0].text == "This is a {{ variable }}" assert builder._variables is None assert builder._required_variables == ["variable"] @@ -84,7 +84,7 @@ def test_init_with_custom_variables(self): builder = ChatPromptBuilder(template=template, variables=variables) assert builder.required_variables == [] assert builder._variables == variables - assert builder.template[0].content == "Hello, {{ var1 }}, {{ var2 }}!" + assert builder.template[0].text == "Hello, {{ var1 }}, {{ var2 }}!" assert builder._required_variables is None # we have inputs that contain: template, template_variables + variables diff --git a/test/components/connectors/test_openapi_service.py b/test/components/connectors/test_openapi_service.py index b5681f3218..4e488012c4 100644 --- a/test/components/connectors/test_openapi_service.py +++ b/test/components/connectors/test_openapi_service.py @@ -182,7 +182,7 @@ def test_run_with_mix_params_request_body(self, openapi_mock, test_files_path): # verify call went through on the wire mock_service.call_greet.assert_called_once_with(parameters={"name": "John"}, data={"message": "Hello"}) - response = json.loads(result["service_response"][0].content) + response = json.loads(result["service_response"][0].text) assert response == "Hello, John" @patch("haystack.components.connectors.openapi_service.OpenAPI") @@ -259,7 +259,7 @@ def test_run_with_complex_types(self, openapi_mock, test_files_path): } ) - response = json.loads(result["service_response"][0].content) + response = json.loads(result["service_response"][0].text) assert response == {"result": "accepted"} @patch("haystack.components.connectors.openapi_service.OpenAPI") diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index fdafec682d..c104d0e725 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -113,7 +113,7 @@ def test_live_run(self): results = component.run(chat_messages) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert "gpt-4o-mini" in message.meta["model"] assert message.meta["finish_reason"] == "stop" diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 3f4dbdd06a..433917ec23 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -192,7 +192,7 @@ def test_run(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): assert isinstance(results["replies"][0], ChatMessage) chat_message = results["replies"][0] assert chat_message.is_from(ChatRole.ASSISTANT) - assert chat_message.content == "Berlin is cool" + assert chat_message.text == "Berlin is cool" def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf") @@ -216,4 +216,4 @@ def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipel assert isinstance(results["replies"][0], ChatMessage) chat_message = results["replies"][0] assert chat_message.is_from(ChatRole.ASSISTANT) - assert chat_message.content == "Berlin is cool" + assert chat_message.text == "Berlin is cool" diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 9c6fe1db5d..0461ba3cde 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -218,7 +218,7 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk + assert "Hello" in response["replies"][0].text # see mock_chat_completion_chunk @patch("haystack.components.generators.chat.openai.datetime") def test_run_with_streaming_callback_in_run_method(self, mock_datetime, chat_messages, mock_chat_completion_chunk): @@ -240,7 +240,7 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk + assert "Hello" in response["replies"][0].text # see mock_chat_completion_chunk assert hasattr(response["replies"][0], "meta") assert isinstance(response["replies"][0].meta, dict) @@ -287,7 +287,7 @@ def test_live_run(self): results = component.run(chat_messages) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert "gpt-4o-mini" in message.meta["model"] assert message.meta["finish_reason"] == "stop" @@ -322,7 +322,7 @@ def __call__(self, chunk: StreamingChunk) -> None: assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert "gpt-4o-mini" in message.meta["model"] assert message.meta["finish_reason"] == "stop" @@ -353,7 +353,7 @@ def __call__(self, chunk: StreamingChunk) -> None: assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert "gpt-4o-mini" in message.meta["model"] assert message.meta["finish_reason"] == "stop" diff --git a/test/components/validators/test_json_schema.py b/test/components/validators/test_json_schema.py index 39e9c78587..a407d7c103 100644 --- a/test/components/validators/test_json_schema.py +++ b/test/components/validators/test_json_schema.py @@ -184,7 +184,7 @@ def run(self): result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}}) assert "validated" in result["schema_validator"] assert len(result["schema_validator"]["validated"]) == 1 - assert result["schema_validator"]["validated"][0].content == genuine_fc_message + assert result["schema_validator"]["validated"][0].text == genuine_fc_message def test_schema_validator_in_pipeline_validation_error(self, json_schema_github_compare): @component @@ -202,4 +202,4 @@ def run(self): result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}}) assert "validation_error" in result["schema_validator"] assert len(result["schema_validator"]["validation_error"]) == 1 - assert "Error details" in result["schema_validator"]["validation_error"][0].content + assert "Error details" in result["schema_validator"]["validation_error"][0].text diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py index 6a82cf4dbd..d7001a0187 100644 --- a/test/core/pipeline/features/test_run.py +++ b/test/core/pipeline/features/test_run.py @@ -603,7 +603,7 @@ def run(self, prompt_source: List[ChatMessage]): class MessageMerger: @component.output_types(merged_message=str) def run(self, messages: List[ChatMessage], metadata: dict = None): - return {"merged_message": "\n".join(t.content for t in messages)} + return {"merged_message": "\n".join(t.text or "" for t in messages)} @component class FakeGenerator: diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index cffd4c94da..30ad51630e 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -114,9 +114,8 @@ def test_from_dict(): def test_from_dict_with_meta(): - assert ChatMessage.from_dict( - data={"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}} - ) == ChatMessage.from_assistant("text", meta={"something": "something"}) + data = {"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}} + assert ChatMessage.from_dict(data) == ChatMessage.from_assistant("text", meta={"something": "something"}) def test_content_deprecation_warning(recwarn):