Skip to content

Commit

Permalink
[ci] add gemma test
Browse files Browse the repository at this point in the history
  • Loading branch information
huyiwen committed Jul 15, 2024
1 parent a71e77a commit 66fdd4c
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/.test_durations
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"tests/utilization/dataset/test_formatting.py::test_multi_turn": 5.1760842725634575,
"tests/utilization/model/test_apply_prompt_template.py::test_base": 0.0028072865679860115,
"tests/utilization/model/test_apply_prompt_template.py::test_final_strip": 0.0016121016815304756,
"tests/utilization/model/test_apply_prompt_template.py::test_gemma": 0.009262016043066978,
"tests/utilization/model/test_apply_prompt_template.py::test_llama2": 0.0018699830397963524,
"tests/utilization/model/test_apply_prompt_template.py::test_no_smart_space": 0.0017823278903961182,
"tests/utilization/model/test_apply_prompt_template.py::test_phi3": 0.009263357031159103,
Expand Down
17 changes: 17 additions & 0 deletions tests/utilization/model/test_apply_prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ def test_phi3(conversation: Conversation):
)


def test_gemma(conversation: Conversation):
formatter = ConversationFormatter.from_chat_template("gemma")
conversation.set_formatter(formatter)
formatted_conversation = conversation.apply_prompt_template()
assert formatted_conversation == (
"<bos><start_of_turn>user\n"
"This is a system message.\n"
"This is a user message.<end_of_turn>\n"
"<start_of_turn>model\n"
"This is an assistant message.<end_of_turn>\n"
"<start_of_turn>user\n"
"This is the second user message.<end_of_turn>\n"
"<start_of_turn>model\n"
"This is the second assistant message.<end_of_turn>\n"
)


def test_no_smart_space(conversation: Conversation):
prompt_config = {
"system_start": "",
Expand Down
5 changes: 3 additions & 2 deletions utilization/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,11 @@ def smart_space(parts: List[Tuple[str, bool]], auto_leading_space: bool, no_spac
"gemma": {
"all_start": "<bos>",
"merge_system_to_user": True,
"system_user_sep": "\n",
"user_start": "<start_of_turn>user\n",
"user_end": "<end_of_turn>",
"user_end": "<end_of_turn>\n",
"assistant_start": "<start_of_turn>model\n",
"assistant_end": "<end_of_turn>",
"assistant_end": "<end_of_turn>\n",
"auto_leading_space": True,
"final_rstrip": False,
"no_space_between": True,
Expand Down
6 changes: 5 additions & 1 deletion utilization/model/model_utils/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
self.final_lstrip = chat_config.pop("final_lstrip", True)
self.final_rstrip = chat_config.pop("final_rstrip", True)
self.merge_system_to_user = chat_config.pop("merge_system_to_user", False)
self.system_user_sep = chat_config.pop("system_user_sep", "\n")

# api model does not need bos_token
if "bos_token" not in chat_config:
Expand Down Expand Up @@ -379,8 +380,10 @@ def get_generation_results(self) -> Union[str, Tuple[str, ...]]:
def _merge_system_to_user(self):
"""Whether to convert system message to part of next user message."""
if self.merge_system_to_user and self.messages[0]["role"] == "system":
self.messages[1]["content"] = self.messages[0]["content"] + self.messages[1]["content"]
msg = self.messages[0]["content"] + self.system_user_sep + self.messages[1]["content"]

self.messages.pop(0)
self.messages[0]["content"] = msg

def set_formatter(
self,
Expand All @@ -392,6 +395,7 @@ def set_formatter(
self.model_evaluation_method = model_evaluation_method
self.split = split and self.get_segs_num() > 1
self.merge_system_to_user = self.formatter.merge_system_to_user
self.system_user_sep = self.formatter.system_user_sep
self._merge_system_to_user()

def to_model_prompt(
Expand Down

0 comments on commit 66fdd4c

Please sign in to comment.