Skip to content

Commit

Permalink
fix untrained tokens if specified explicitly from a list (#2210)
Browse files: browse the repository at this point in the history
  • Loading branch information
winglian authored Dec 23, 2024
1 parent d852d7a commit e0a2eb2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0

axolotl-contribs-lgpl==0.0.1b2
axolotl-contribs-lgpl==0.0.2
16 changes: 15 additions & 1 deletion src/axolotl/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import inspect
import os
import signal
import sys
Expand Down Expand Up @@ -126,7 +127,20 @@ def train(
)

if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset)
# check whether the available fix_untrained_tokens accepts the `token_ids_to_fix` kwarg (presumably older releases do not; verify)
sig = inspect.signature(fix_untrained_tokens)
# pass explicit ids only when the function accepts `token_ids_to_fix` and cfg.fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ class Config:
chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None

fix_untrained_tokens: Optional[bool] = None
fix_untrained_tokens: Optional[Union[int, List[int]]] = None

# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
Expand Down

0 comments on commit e0a2eb2

Please sign in to comment.