diff --git a/logix/huggingface/__init__.py b/logix/huggingface/__init__.py
index 7fa738a..1df2835 100644
--- a/logix/huggingface/__init__.py
+++ b/logix/huggingface/__init__.py
@@ -1,2 +1,2 @@
 from .patch import patch_trainer
-from .arguments import LogIXArgument
+from .arguments import LogIXArguments
diff --git a/logix/huggingface/arguments.py b/logix/huggingface/arguments.py
index 6d01794..8de39ab 100644
--- a/logix/huggingface/arguments.py
+++ b/logix/huggingface/arguments.py
@@ -5,7 +5,7 @@
 
 
 @dataclass
-class LogIXArgument:
+class LogIXArguments:
     project: str = field(
         default="tmp_logix", metadata={"help": "The name of the project."}
     )
diff --git a/logix/huggingface/callback.py b/logix/huggingface/callback.py
index 3d8fe12..0ad960b 100644
--- a/logix/huggingface/callback.py
+++ b/logix/huggingface/callback.py
@@ -3,7 +3,7 @@
 from transformers.trainer import TrainerCallback
 
 from logix import LogIX, LogIXScheduler
-from logix.huggingface.arguments import LogIXArgument
+from logix.huggingface.arguments import LogIXArguments
 
 
 class LogIXCallback(TrainerCallback):
@@ -11,7 +11,7 @@ def __init__(
         self,
         logix: LogIX,
         logix_scheduler: LogIXScheduler,
-        args: LogIXArgument,
+        args: LogIXArguments,
     ):
         self.logix = logix
         self.logix_scheduler = iter(logix_scheduler)
diff --git a/logix/huggingface/patch.py b/logix/huggingface/patch.py
index 3acb01e..4678769 100644
--- a/logix/huggingface/patch.py
+++ b/logix/huggingface/patch.py
@@ -5,7 +5,7 @@
 from logix import LogIX, LogIXScheduler
 from logix.utils import DataIDGenerator
 from logix.huggingface.callback import LogIXCallback
-from logix.huggingface.arguments import LogIXArgument
+from logix.huggingface.arguments import LogIXArguments
 
 
 def patch_trainer(TrainerClass):
@@ -19,9 +19,9 @@ def patch_trainer(TrainerClass):
     class PatchedTrainer(TrainerClass):
         def __init__(
             self,
-            logix_args: Optional[LogIXArgument] = None,
+            logix_args: LogIXArguments,
             model: Union[PreTrainedModel, nn.Module] = None,
-            args: TrainingArguments = None,
+            args: Optional[TrainingArguments] = None,
             data_collator: Optional[DataCollator] = None,
             train_dataset: Optional[Dataset] = None,
             eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
@@ -59,6 +59,7 @@ def __init__(
             args = TrainingArguments(output_dir=output_dir)
             args.num_train_epochs = len(self.logix_scheduler)
             args.report_to = []
+            args.save_strategy = "no"
 
             super().__init__(
                 model,