diff --git a/requirements.txt b/requirements.txt
index 61e1a9f90..283b5cc2d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -61,4 +61,4 @@
 antlr4-python3-runtime==4.13.2
 torchao==0.7.0
 schedulefree==1.3.0
-axolotl-contribs-lgpl==0.0.1b2
+axolotl-contribs-lgpl==0.0.2
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index dc7289b09..a74ecc2ec 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -1,5 +1,6 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
 
+import inspect
 import os
 import signal
 import sys
@@ -126,7 +127,20 @@ def train(
     )
 
     if cfg.fix_untrained_tokens:
-        fix_untrained_tokens(model, tokenizer, train_dataset)
+        # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
+        sig = inspect.signature(fix_untrained_tokens)
+        # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
+        if "token_ids_to_fix" in sig.parameters and isinstance(
+            cfg.fix_untrained_tokens, list
+        ):
+            fix_untrained_tokens(
+                model,
+                tokenizer,
+                train_dataset,
+                token_ids_to_fix=cfg.fix_untrained_tokens,
+            )
+        else:
+            fix_untrained_tokens(model, tokenizer, train_dataset)
         if cfg.local_rank == 0:
             model.save_pretrained(
                 str(Path(cfg.output_dir)), safe_serialization=safe_serialization
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index c704be800..0781c6798 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -794,7 +794,7 @@ class Config:
     chat_template_jinja: Optional[str] = None
     default_system_message: Optional[str] = None
 
-    fix_untrained_tokens: Optional[bool] = None
+    fix_untrained_tokens: Optional[Union[int, List[int]]] = None
 
     # INTERNALS - document for now, generally not set externally
     is_preprocess: Optional[bool] = None
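
The train.py change uses runtime signature inspection instead of a hard version pin: the new `token_ids_to_fix` kwarg is only forwarded when the installed `axolotl-contribs-lgpl` actually exposes it, so older releases of the package keep working unchanged. A minimal runnable sketch of that feature-detection pattern follows; the `_v1`/`_v2` functions are hypothetical stand-ins for the old and new contribs APIs, not the real package:

```python
import inspect


def fix_untrained_tokens_v1(model, tokenizer, dataset):
    # Hypothetical stand-in for the pre-0.0.2 API: no token_ids_to_fix kwarg.
    print("fixing all untrained tokens")


def fix_untrained_tokens_v2(model, tokenizer, dataset, token_ids_to_fix=None):
    # Hypothetical stand-in for the 0.0.2 API: accepts an explicit id list.
    print(f"fixing tokens: {token_ids_to_fix if token_ids_to_fix else 'all untrained'}")


def call_fix(fix_fn, model, tokenizer, dataset, ids_from_config):
    # Same guard as the patch: only forward the kwarg when the installed
    # function declares it AND the config value is a list of token ids.
    sig = inspect.signature(fix_fn)
    if "token_ids_to_fix" in sig.parameters and isinstance(ids_from_config, list):
        fix_fn(model, tokenizer, dataset, token_ids_to_fix=ids_from_config)
    else:
        fix_fn(model, tokenizer, dataset)


call_fix(fix_untrained_tokens_v1, None, None, None, [128002])  # old API: falls back
call_fix(fix_untrained_tokens_v2, None, None, None, [128002])  # new API: forwards ids
```

The widened config type (`Optional[Union[int, List[int]]]`) pairs with this guard: a list value such as `fix_untrained_tokens: [128002, 128003]` (illustrative token ids) takes the targeted path, while any other truthy value falls through to the original fix-everything behavior.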