diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index b39e9f18f..4765e1a45 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -308,10 +308,10 @@ class BAdamArgument: class SwanLabArguments: use_swanlab: bool = field( default=False, - metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tools)."}, + metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."}, ) swanlab_project: str = field( - default=None, + default="LLaMA Factory", metadata={"help": "The project name in SwanLab."}, ) swanlab_workspace: str = field( diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 7e76dee2d..0115f8344 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -31,7 +31,7 @@ from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback if TYPE_CHECKING: @@ -106,6 +106,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index e22b16a47..718023755 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -30,7 +30,7 @@ from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback if TYPE_CHECKING: @@ -101,6 +101,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 4ab7a1187..a60b7d7cd 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -40,7 +40,7 @@ from ...extras import logging from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm @@ -186,6 +186,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: r""" Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 37dcadfd9..5e4a513d2 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -20,7 +20,7 @@ from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -56,6 +56,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index bccfdef5d..458c40ff3 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -27,7 +27,7 @@ from ...extras import logging from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -68,6 +68,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index e267b63cf..8b5baade5 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1438,7 +1438,7 @@ }, "swanlab_experiment_name": { "en": { - "label": "Experiment_name(optional)", + "label": "Experiment name (optional)", }, "ru": { "label": "Имя эксперимента(Необязательный)",