Skip to content

Commit

Permalink
refactor: update components to access ChatMessage.text instead of `content` (#8589)
Browse files Browse the repository at this point in the history

* introduce text property and deprecate content

* release note

* use chatmessage.text

* release note

* linting
  • Loading branch information
anakin87 authored and Amnah199 committed Dec 3, 2024
1 parent 18de0a1 commit 7d7bf6f
Show file tree
Hide file tree
Showing 15 changed files with 53 additions and 42 deletions.
7 changes: 6 additions & 1 deletion haystack/components/builders/answer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,12 @@ def run( # pylint: disable=too-many-positional-arguments
all_answers = []
for reply, given_metadata in zip(replies, meta):
# Extract content from ChatMessage objects if reply is a ChatMessage, else use the string as is
extracted_reply = reply.content if isinstance(reply, ChatMessage) else str(reply)
if isinstance(reply, ChatMessage):
if reply.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {reply}")
extracted_reply = reply.text
else:
extracted_reply = str(reply)
extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else {}

extracted_metadata = {**extracted_metadata, **given_metadata}
Expand Down
9 changes: 6 additions & 3 deletions haystack/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def __init__(
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infer variables from template
ast = self._env.parse(message.content)
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
ast = self._env.parse(message.text)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
self.variables = variables
Expand Down Expand Up @@ -192,8 +194,9 @@ def run(
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
self._validate_variables(set(template_variables_combined.keys()))

compiled_template = self._env.from_string(message.content)
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
compiled_template = self._env.from_string(message.text)
rendered_content = compiled_template.render(template_variables_combined)
# deep copy the message to avoid modifying the original message
rendered_message: ChatMessage = deepcopy(message)
Expand Down
8 changes: 5 additions & 3 deletions haystack/components/connectors/openapi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,17 @@ def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]:
:raises ValueError: If the content is not valid JSON or lacks required fields.
"""
function_payloads = []
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text.\nChatMessage: {message}")
try:
tool_calls = json.loads(message.content)
tool_calls = json.loads(message.text)
except json.JSONDecodeError:
raise ValueError("Invalid JSON content, expected OpenAI tools message.", message.content)
raise ValueError("Invalid JSON content, expected OpenAI tools message.", message.text)

for tool_call in tool_calls:
# this should never happen, but just in case do a sanity check
if "type" not in tool_call:
raise ValueError("Message payload doesn't seem to be a tool invocation descriptor", message.content)
raise ValueError("Message payload doesn't seem to be a tool invocation descriptor", message.text)

# In OpenAPIServiceConnector we know how to handle functions tools only
if tool_call["type"] == "function":
Expand Down
5 changes: 1 addition & 4 deletions haystack/components/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,7 @@ def run(
for response in completions:
self._check_finish_reason(response)

return {
"replies": [message.content for message in completions],
"meta": [message.meta for message in completions],
}
return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]}

@staticmethod
def _create_message_from_chunks(
Expand Down
5 changes: 4 additions & 1 deletion haystack/components/generators/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]:
- `content`
- `name` (optional)
"""
openai_msg = {"role": message.role.value, "content": message.content}
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")

openai_msg = {"role": message.role.value, "content": message.text}
if message.name:
openai_msg["name"] = message.name

Expand Down
17 changes: 7 additions & 10 deletions haystack/components/validators/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,20 @@ def run(
dictionaries.
"""
last_message = messages[-1]
if not is_valid_json(last_message.content):
if last_message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {last_message}")
if not is_valid_json(last_message.text):
return {
"validation_error": [
ChatMessage.from_user(
f"The message '{last_message.content}' is not a valid JSON object. "
f"The message '{last_message.text}' is not a valid JSON object. "
f"Please provide only a valid JSON object in string format."
f"Don't use any markdown and don't add any comment."
)
]
}

last_message_content = json.loads(last_message.content)
last_message_content = json.loads(last_message.text)
json_schema = json_schema or self.json_schema
error_template = error_template or self.error_template or self.default_error_template

Expand Down Expand Up @@ -182,16 +184,11 @@ def run(
error_template = error_template or self.default_error_template

recovery_prompt = self._construct_error_recovery_message(
error_template,
str(e),
error_path,
error_schema_path,
validation_schema,
failing_json=last_message.content,
error_template, str(e), error_path, error_schema_path, validation_schema, failing_json=last_message.text
)
return {"validation_error": [ChatMessage.from_user(recovery_prompt)]}

def _construct_error_recovery_message(
def _construct_error_recovery_message( # pylint: disable=too-many-positional-arguments
self,
error_template: str,
error_message: str,
Expand Down
5 changes: 5 additions & 0 deletions releasenotes/notes/use-chatmessage-text-266c94d742c76d32.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Replace usage of `ChatMessage.content` with `ChatMessage.text` across the codebase.
This is done in preparation for the removal of `content` in Haystack 2.9.0.
8 changes: 4 additions & 4 deletions test/components/builders/test_chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def test_init(self):
]
)
assert builder.required_variables == []
assert builder.template[0].content == "This is a {{ variable }}"
assert builder.template[1].content == "This is a {{ variable2 }}"
assert builder.template[0].text == "This is a {{ variable }}"
assert builder.template[1].text == "This is a {{ variable2 }}"
assert builder._variables is None
assert builder._required_variables is None

Expand Down Expand Up @@ -62,7 +62,7 @@ def test_init_with_required_variables(self):
template=[ChatMessage.from_user("This is a {{ variable }}")], required_variables=["variable"]
)
assert builder.required_variables == ["variable"]
assert builder.template[0].content == "This is a {{ variable }}"
assert builder.template[0].text == "This is a {{ variable }}"
assert builder._variables is None
assert builder._required_variables == ["variable"]

Expand All @@ -84,7 +84,7 @@ def test_init_with_custom_variables(self):
builder = ChatPromptBuilder(template=template, variables=variables)
assert builder.required_variables == []
assert builder._variables == variables
assert builder.template[0].content == "Hello, {{ var1 }}, {{ var2 }}!"
assert builder.template[0].text == "Hello, {{ var1 }}, {{ var2 }}!"
assert builder._required_variables is None

# we have inputs that contain: template, template_variables + variables
Expand Down
4 changes: 2 additions & 2 deletions test/components/connectors/test_openapi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_run_with_mix_params_request_body(self, openapi_mock, test_files_path):
# verify call went through on the wire
mock_service.call_greet.assert_called_once_with(parameters={"name": "John"}, data={"message": "Hello"})

response = json.loads(result["service_response"][0].content)
response = json.loads(result["service_response"][0].text)
assert response == "Hello, John"

@patch("haystack.components.connectors.openapi_service.OpenAPI")
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_run_with_complex_types(self, openapi_mock, test_files_path):
}
)

response = json.loads(result["service_response"][0].content)
response = json.loads(result["service_response"][0].text)
assert response == {"result": "accepted"}

@patch("haystack.components.connectors.openapi_service.OpenAPI")
Expand Down
2 changes: 1 addition & 1 deletion test/components/generators/chat/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_live_run(self):
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "Paris" in message.text
assert "gpt-4o-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"

Expand Down
4 changes: 2 additions & 2 deletions test/components/generators/chat/test_hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_run(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
assert isinstance(results["replies"][0], ChatMessage)
chat_message = results["replies"][0]
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.content == "Berlin is cool"
assert chat_message.text == "Berlin is cool"

def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
Expand All @@ -216,4 +216,4 @@ def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipel
assert isinstance(results["replies"][0], ChatMessage)
chat_message = results["replies"][0]
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.content == "Berlin is cool"
assert chat_message.text == "Berlin is cool"
10 changes: 5 additions & 5 deletions test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
assert "Hello" in response["replies"][0].text # see mock_chat_completion_chunk

@patch("haystack.components.generators.chat.openai.datetime")
def test_run_with_streaming_callback_in_run_method(self, mock_datetime, chat_messages, mock_chat_completion_chunk):
Expand All @@ -240,7 +240,7 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
assert "Hello" in response["replies"][0].text # see mock_chat_completion_chunk

assert hasattr(response["replies"][0], "meta")
assert isinstance(response["replies"][0].meta, dict)
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_live_run(self):
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "Paris" in message.text
assert "gpt-4o-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"

Expand Down Expand Up @@ -322,7 +322,7 @@ def __call__(self, chunk: StreamingChunk) -> None:

assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "Paris" in message.text

assert "gpt-4o-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"
Expand Down Expand Up @@ -353,7 +353,7 @@ def __call__(self, chunk: StreamingChunk) -> None:

assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "Paris" in message.text

assert "gpt-4o-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"
Expand Down
4 changes: 2 additions & 2 deletions test/components/validators/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def run(self):
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
assert "validated" in result["schema_validator"]
assert len(result["schema_validator"]["validated"]) == 1
assert result["schema_validator"]["validated"][0].content == genuine_fc_message
assert result["schema_validator"]["validated"][0].text == genuine_fc_message

def test_schema_validator_in_pipeline_validation_error(self, json_schema_github_compare):
@component
Expand All @@ -202,4 +202,4 @@ def run(self):
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
assert "validation_error" in result["schema_validator"]
assert len(result["schema_validator"]["validation_error"]) == 1
assert "Error details" in result["schema_validator"]["validation_error"][0].content
assert "Error details" in result["schema_validator"]["validation_error"][0].text
2 changes: 1 addition & 1 deletion test/core/pipeline/features/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def run(self, prompt_source: List[ChatMessage]):
class MessageMerger:
@component.output_types(merged_message=str)
def run(self, messages: List[ChatMessage], metadata: dict = None):
return {"merged_message": "\n".join(t.content for t in messages)}
return {"merged_message": "\n".join(t.text or "" for t in messages)}

@component
class FakeGenerator:
Expand Down
5 changes: 2 additions & 3 deletions test/dataclasses/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,8 @@ def test_from_dict():


def test_from_dict_with_meta():
assert ChatMessage.from_dict(
data={"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}}
) == ChatMessage.from_assistant("text", meta={"something": "something"})
data = {"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}}
assert ChatMessage.from_dict(data) == ChatMessage.from_assistant("text", meta={"something": "something"})


def test_content_deprecation_warning(recwarn):
Expand Down

0 comments on commit 7d7bf6f

Please sign in to comment.