From 66fdd4cfae59cd62ff67944b8193485d2245d7d8 Mon Sep 17 00:00:00 2001
From: huyiwen <1020030101@qq.com>
Date: Mon, 15 Jul 2024 21:28:50 +0800
Subject: [PATCH] [ci] add gemma test

---
 .github/.test_durations                       |  1 +
 .../model/test_apply_prompt_template.py       | 17 +++++++++++++++++
 utilization/chat_templates.py                 |  5 +++--
 utilization/model/model_utils/conversation.py |  6 +++++-
 4 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/.github/.test_durations b/.github/.test_durations
index b829824f..d3fe680f 100644
--- a/.github/.test_durations
+++ b/.github/.test_durations
@@ -93,6 +93,7 @@
   "tests/utilization/dataset/test_formatting.py::test_multi_turn": 5.1760842725634575,
   "tests/utilization/model/test_apply_prompt_template.py::test_base": 0.0028072865679860115,
   "tests/utilization/model/test_apply_prompt_template.py::test_final_strip": 0.0016121016815304756,
+  "tests/utilization/model/test_apply_prompt_template.py::test_gemma": 0.009262016043066978,
   "tests/utilization/model/test_apply_prompt_template.py::test_llama2": 0.0018699830397963524,
   "tests/utilization/model/test_apply_prompt_template.py::test_no_smart_space": 0.0017823278903961182,
   "tests/utilization/model/test_apply_prompt_template.py::test_phi3": 0.009263357031159103,
diff --git a/tests/utilization/model/test_apply_prompt_template.py b/tests/utilization/model/test_apply_prompt_template.py
index 960a3484..9825de9f 100644
--- a/tests/utilization/model/test_apply_prompt_template.py
+++ b/tests/utilization/model/test_apply_prompt_template.py
@@ -72,6 +72,23 @@ def test_phi3(conversation: Conversation):
     )
 
 
+def test_gemma(conversation: Conversation):
+    formatter = ConversationFormatter.from_chat_template("gemma")
+    conversation.set_formatter(formatter)
+    formatted_conversation = conversation.apply_prompt_template()
+    assert formatted_conversation == (
+        "<start_of_turn>user\n"
+        "This is a system message.\n"
+        "This is a user message.<end_of_turn>\n"
+        "<start_of_turn>model\n"
+        "This is an assistant message.<end_of_turn>\n"
+        "<start_of_turn>user\n"
+        "This is the second user message.<end_of_turn>\n"
+        "<start_of_turn>model\n"
+        "This is the second assistant message.<end_of_turn>\n"
+    )
+
+
 def test_no_smart_space(conversation: Conversation):
     prompt_config = {
         "system_start": "",
diff --git a/utilization/chat_templates.py b/utilization/chat_templates.py
index af454228..cf94cecc 100644
--- a/utilization/chat_templates.py
+++ b/utilization/chat_templates.py
@@ -176,10 +176,11 @@ def smart_space(parts: List[Tuple[str, bool]], auto_leading_space: bool, no_spac
     "gemma": {
         "all_start": "",
         "merge_system_to_user": True,
+        "system_user_sep": "\n",
         "user_start": "<start_of_turn>user\n",
-        "user_end": "<end_of_turn>",
+        "user_end": "<end_of_turn>\n",
         "assistant_start": "<start_of_turn>model\n",
-        "assistant_end": "<end_of_turn>",
+        "assistant_end": "<end_of_turn>\n",
         "auto_leading_space": True,
         "final_rstrip": False,
         "no_space_between": True,
diff --git a/utilization/model/model_utils/conversation.py b/utilization/model/model_utils/conversation.py
index 919f3f5d..2047d32a 100644
--- a/utilization/model/model_utils/conversation.py
+++ b/utilization/model/model_utils/conversation.py
@@ -57,6 +57,7 @@ def __init__(
         self.final_lstrip = chat_config.pop("final_lstrip", True)
         self.final_rstrip = chat_config.pop("final_rstrip", True)
         self.merge_system_to_user = chat_config.pop("merge_system_to_user", False)
+        self.system_user_sep = chat_config.pop("system_user_sep", "\n")
 
         # api model does not need bos_token
         if "bos_token" not in chat_config:
@@ -379,8 +380,10 @@ def get_generation_results(self) -> Union[str, Tuple[str, ...]]:
 
     def _merge_system_to_user(self):
         """Whether to convert system message to part of next user message."""
         if self.merge_system_to_user and self.messages[0]["role"] == "system":
-            self.messages[1]["content"] = self.messages[0]["content"] + self.messages[1]["content"]
+            msg = self.messages[0]["content"] + self.system_user_sep + self.messages[1]["content"]
+            self.messages.pop(0)
+            self.messages[0]["content"] = msg
 
     def set_formatter(
         self,
@@ -392,6 +395,7 @@ def set_formatter(
         self.model_evaluation_method = model_evaluation_method
         self.split = split and self.get_segs_num() > 1
         self.merge_system_to_user = self.formatter.merge_system_to_user
+        self.system_user_sep = self.formatter.system_user_sep
         self._merge_system_to_user()
 
     def to_model_prompt(
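
For reference, a minimal standalone sketch of the merge step this patch implements in `Conversation._merge_system_to_user`: the leading system message is joined to the first user message with a configurable `system_user_sep` and then dropped from the message list. The free function `merge_system_to_user` below is a hypothetical stand-in, not the repository's API, and assumes messages are plain dicts with `role` and `content` keys:

```python
from typing import Dict, List

# Hypothetical message shape: {"role": "system" | "user" | "assistant", "content": "..."}
Msg = Dict[str, str]


def merge_system_to_user(messages: List[Msg], system_user_sep: str = "\n") -> List[Msg]:
    """Fold a leading system message into the next user message
    (sketch of the patched Conversation._merge_system_to_user)."""
    if len(messages) > 1 and messages[0]["role"] == "system":
        merged = messages[0]["content"] + system_user_sep + messages[1]["content"]
        messages = messages[1:]  # drop the system message
        messages[0] = {**messages[0], "content": merged}
    return messages


messages = [
    {"role": "system", "content": "This is a system message."},
    {"role": "user", "content": "This is a user message."},
]
print(merge_system_to_user(messages)[0]["content"])
# This is a system message.
# This is a user message.
```

With the gemma template's `system_user_sep` of `"\n"`, the system text lands on its own line inside the first `<start_of_turn>user` turn, which is what the new `test_gemma` asserts.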