diff --git a/utilization/chat_templates.py b/utilization/chat_templates.py index 4f74cf7e..af454228 100644 --- a/utilization/chat_templates.py +++ b/utilization/chat_templates.py @@ -7,11 +7,11 @@ def add_space( msg: str, context: str, auto_leading_space: bool = True, - remove_space_between: bool = True, + no_space_between: bool = True, starts: Optional[List[str]] = None, ends: Optional[List[str]] = None ) -> str: - if starts is None or ends is None or remove_space_between is False: + if starts is None or ends is None or no_space_between is False: context_ends_special = False msg_starts_special = False else: @@ -24,9 +24,7 @@ def add_space( return msg -def smart_space( - parts: List[Tuple[str, bool]], auto_leading_space: bool, remove_space_between: bool, seq: List[str] -) -> str: +def smart_space(parts: List[Tuple[str, bool]], auto_leading_space: bool, no_space_between: bool, seq: List[str]) -> str: starts = [seq[role + "_start"] for role in ["system", "user", "assistant"] if (role + "_start") in seq] ends = [seq[role + "_end"] for role in ["system", "user", "assistant"] if (role + "_end") in seq] if "bos_token" in seq: @@ -38,7 +36,7 @@ def smart_space( part[0], rendered, auto_leading_space=auto_leading_space and part[1], - remove_space_between=remove_space_between, + no_space_between=no_space_between, starts=starts, ends=ends ) @@ -65,7 +63,7 @@ def smart_space( "{%- set data.parts = data.parts + [(seq['generation_prompt'], True)] -%}" "{%- endif -%}" "" - "{{ data.parts | smart_space(auto_leading_space, remove_space_between, seq) }}" + "{{ data.parts | smart_space(auto_leading_space, no_space_between, seq) }}" ) # Chat configs format: @@ -86,7 +84,9 @@ def smart_space( # - assistant_end: The string to append to the assistant message. # - auto_leading_space: Whether to add a leading space when concatenating two # strings if the first string does not end with a whitespace. +# - no_space_between: Whether to not add the leading space between special tokens. 
# - default_stop: A list of strings that indicate the end of a message. +# - merge_system_to_user: Whether to merge the system message into the next user message. # DEFAULT_CHAT_CONFIGS: Dict[str, Union[Dict[str, Any], str]] = { "base": { @@ -98,7 +98,7 @@ def smart_space( "assistant_end": "\n\n", "auto_leading_space": True, "final_rstrip": True, - "remove_space_between": False, + "no_space_between": False, "default_stop": [], }, "llama2": { @@ -110,7 +110,7 @@ def smart_space( "assistant_end": " ", "auto_leading_space": True, "final_rstrip": False, - "remove_space_between": True, + "no_space_between": True, "default_stop": [], }, "chatml": { @@ -122,7 +122,7 @@ def smart_space( "assistant_end": "<|im_end|>\n", "auto_leading_space": True, "final_rstrip": False, - "remove_space_between": True, + "no_space_between": True, "default_stop": ["<|im_end|>"], }, "zephyr": { @@ -134,7 +134,7 @@ def smart_space( "assistant_end": "</s>\n", "auto_leading_space": True, "final_rstrip": False, - "remove_space_between": True, + "no_space_between": True, "default_stop": ["</s>"], }, "phi3": { @@ -146,7 +146,7 @@ def smart_space( "assistant_end": "<|end|>\n", "auto_leading_space": True, "final_rstrip": False, - "remove_space_between": True, + "no_space_between": True, "default_stop": ["<|end|>", "<|endoftext|>"], }, "llama3": { @@ -158,7 +158,7 @@ def smart_space( "assistant_end": "<|eot_id|>", "auto_leading_space": True, "final_rstrip": False, - "remove_space_between": True, + "no_space_between": True, "default_stop": ["<|eot_id|>"], }, "alpaca": { @@ -170,7 +170,19 @@ def smart_space( "assistant_end": "\n\n", "auto_leading_space": True, "final_rstrip": False, - "remove_space_between": False, + "no_space_between": False, "default_stop": ["###"], + }, + "gemma": { + "all_start": "", + "merge_system_to_user": True, + "user_start": "<start_of_turn>user\n", + "user_end": "<end_of_turn>", + "assistant_start": "<start_of_turn>model\n", + "assistant_end": "<end_of_turn>", + "auto_leading_space": True, + "final_rstrip": False, + "no_space_between": 
True, + "default_stop": ["<end_of_turn>", "<eos>"], } } diff --git a/utilization/model/model_utils/conversation.py b/utilization/model/model_utils/conversation.py index f1fa4651..919f3f5d 100644 --- a/utilization/model/model_utils/conversation.py +++ b/utilization/model/model_utils/conversation.py @@ -56,6 +56,7 @@ def __init__( self.auto_leading_space = chat_config.pop("auto_leading_space", True) self.final_lstrip = chat_config.pop("final_lstrip", True) self.final_rstrip = chat_config.pop("final_rstrip", True) + self.merge_system_to_user = chat_config.pop("merge_system_to_user", False) # api model does not need bos_token if "bos_token" not in chat_config: @@ -375,6 +376,12 @@ def get_generation_results(self) -> Union[str, Tuple[str, ...]]: assert self.messages[-1]["role"] == "assistant" return self.messages[-1]["content"] + def _merge_system_to_user(self): + """Merge a leading system message into the next user message.""" + if self.merge_system_to_user and len(self.messages) > 1 and self.messages[0]["role"] == "system": + self.messages[1]["content"] = self.messages[0]["content"] + self.messages[1]["content"] + self.messages.pop(0) + def set_formatter( self, formatter: ConversationFormatter, @@ -384,6 +391,8 @@ def set_formatter( self.formatter = formatter self.model_evaluation_method = model_evaluation_method self.split = split and self.get_segs_num() > 1 + self.merge_system_to_user = self.formatter.merge_system_to_user + self._merge_system_to_user() def to_model_prompt( self, diff --git a/utilization/utils/arguments.py b/utilization/utils/arguments.py index 09a770ca..12cc0c85 100644 --- a/utilization/utils/arguments.py +++ b/utilization/utils/arguments.py @@ -336,7 +336,8 @@ def __post_init__(self): if self.model_name_or_path in API_MODELS: auto_model_type = API_MODELS[self.model_name_or_path]["model_type"] elif self.is_local_model(): - auto_model_type = "chat" if re.search(r"chat|instruct", self.model_name_or_path.lower()) else "base" + # gemma marks instruction-tuned models with an "-it" suffix + auto_model_type = "chat" if 
re.search(r"chat|instruct|-it\b", self.model_name_or_path.lower()) else "base" else: auto_model_type = None