Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add 'right' option for 'truncation_strategy' #2754

Merged
merged 1 commit into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/Instruction/命令行参数.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
- 🔥template: 对话模板类型,默认使用model对应的template类型。`swift pt`会将对话模版转为生成模板使用
- 🔥system: 自定义system字段,默认为None,使用template的默认system
- 🔥max_length: 单样本的tokens最大长度,默认为None,不做限制
- truncation_strategy: 如果超长如何处理,支持`delete``left`,代表删除和左侧裁剪,默认为'delete'
- truncation_strategy: 如果超长如何处理,支持`delete`, `left`和`right`,代表删除、左侧裁剪和右侧裁剪,默认为'delete'
- 🔥max_pixels: 多模态模型图片前处理的最大像素数(H\*W),默认不缩放。
- tools_prompt: 智能体训练时的工具列表转为system的格式,请参考[智能体训练](./智能体的支持.md),默认为'react_en'
- loss_scale: 如何针对训练添加token的loss权重。默认为`'default'`,代表所有response(含history)以1计算交叉熵损失。具体可以查看[插件化](../Customization/插件化.md)和[智能体训练](./智能体的支持.md)
Expand Down
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ The introduction to command line parameters will cover base arguments, atomic ar
- 🔥template: Type of dialogue template, which defaults to the template type corresponding to the model. `swift pt` will convert the dialogue template into a generation template for use.
- 🔥system: Custom system field, default is None, uses the default system of the template.
- 🔥max_length: Maximum length of tokens for a single sample, default is None (no limit).
- truncation_strategy: How to handle overly long tokens, supports `delete` and `left`, representing deletion and left trimming, default is 'delete'.
- truncation_strategy: How to handle overly long tokens, supports `delete`, `left`, `right`, representing deletion, left trimming, and right trimming, default is 'delete'.
- 🔥max_pixels: Maximum pixel count for pre-processing images in multimodal models (H*W), default is no scaling.
- tools_prompt: The list of tools for agent training converted to system format, refer to [Agent Training](./Agent-support.md), default is 'react_en'.
- loss_scale: How to add token loss weight during training. Default is `'default'`, meaning all responses (including history) are treated as 1 for cross-entropy loss. For specifics, see [Pluginization](../Customization/Pluginization.md) and [Agent Training](./Agent-support.md).
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/argument/base_args/template_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TemplateArguments:
system: Optional[str] = None # Override the default_system in the template.
max_length: Optional[int] = None

truncation_strategy: Literal['delete', 'left'] = 'delete'
truncation_strategy: Literal['delete', 'left', 'right'] = 'delete'
max_pixels: Optional[int] = None
tools_prompt: str = 'react_en' # Override the default_tools_prompt in the template.
# train
Expand Down
19 changes: 13 additions & 6 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
*,
use_chat_template: bool = True,
template_backend: Literal['swift', 'jinja'] = 'swift',
truncation_strategy: Literal['raise', 'left'] = 'raise',
truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
max_pixels: Optional[int] = None,
tools_prompt: Optional[str] = None,
# only for train
Expand Down Expand Up @@ -630,11 +630,18 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
if self.truncation_strategy == 'raise' and len(input_ids) > self.max_length:
raise MaxLengthError(f'Current length of row({len(input_ids)}) is larger'
f' than the max_length({self.max_length}).')
input_ids = input_ids[-self.max_length:]
if labels is not None:
labels = labels[-self.max_length:]
if loss_scale is not None:
loss_scale = loss_scale[-self.max_length:]
elif self.truncation_strategy == 'right':
input_ids = input_ids[:self.max_length]
if labels is not None:
labels = labels[:self.max_length]
if loss_scale is not None:
loss_scale = loss_scale[:self.max_length]
else:
input_ids = input_ids[-self.max_length:]
if labels is not None:
labels = labels[-self.max_length:]
if loss_scale is not None:
loss_scale = loss_scale[-self.max_length:]

encoded['input_ids'] = input_ids
encoded['labels'] = labels
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/template/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_template(
*,
use_chat_template: bool = True,
template_backend: Literal['swift', 'jinja'] = 'swift',
truncation_strategy: Literal['raise', 'left'] = 'raise',
truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
max_pixels: Optional[int] = None, # h * w
tools_prompt: str = 'react_en',
# train
Expand Down
Loading