Skip to content

Commit

Permalink
disable saving in HF Trainer integration
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed May 31, 2024
1 parent 2038014 commit 1871908
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion logix/huggingface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .patch import patch_trainer
from .arguments import LogIXArgument
from .arguments import LogIXArguments
2 changes: 1 addition & 1 deletion logix/huggingface/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@dataclass
class LogIXArgument:
class LogIXArguments:
project: str = field(
default="tmp_logix", metadata={"help": "The name of the project."}
)
Expand Down
4 changes: 2 additions & 2 deletions logix/huggingface/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from transformers.trainer import TrainerCallback

from logix import LogIX, LogIXScheduler
from logix.huggingface.arguments import LogIXArgument
from logix.huggingface.arguments import LogIXArguments


class LogIXCallback(TrainerCallback):
def __init__(
self,
logix: LogIX,
logix_scheduler: LogIXScheduler,
args: LogIXArgument,
args: LogIXArguments,
):
self.logix = logix
self.logix_scheduler = iter(logix_scheduler)
Expand Down
7 changes: 4 additions & 3 deletions logix/huggingface/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from logix import LogIX, LogIXScheduler
from logix.utils import DataIDGenerator
from logix.huggingface.callback import LogIXCallback
from logix.huggingface.arguments import LogIXArgument
from logix.huggingface.arguments import LogIXArguments


def patch_trainer(TrainerClass):
Expand All @@ -19,9 +19,9 @@ def patch_trainer(TrainerClass):
class PatchedTrainer(TrainerClass):
def __init__(
self,
logix_args: Optional[LogIXArgument] = None,
logix_args: LogIXArguments,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
Expand Down Expand Up @@ -59,6 +59,7 @@ def __init__(
args = TrainingArguments(output_dir=output_dir)
args.num_train_epochs = len(self.logix_scheduler)
args.report_to = []
args.save_strategy = "no"

super().__init__(
model,
Expand Down

0 comments on commit 1871908

Please sign in to comment.