diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 2df9113e1e27ba..7dadb708d180f3 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -19,7 +19,7 @@
 from tqdm.auto import tqdm, trange
 
 from .data.data_collator import DataCollator, default_data_collator
-from .file_utils import is_apex_available, is_torch_tpu_available
+from .file_utils import is_torch_tpu_available
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
 from .trainer_utils import (
@@ -33,8 +33,19 @@
 from .training_args import TrainingArguments
 
 
-if is_apex_available():
-    from apex import amp
+_use_native_amp = False
+_use_apex = False
+
+# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
+if version.parse(torch.__version__) < version.parse("1.6"):
+    from transformers.file_utils import is_apex_available
+
+    if is_apex_available():
+        from apex import amp
+    _use_apex = True
+else:
+    _use_native_amp = True
+    from torch.cuda.amp import autocast
 
 
 if is_torch_tpu_available():
@@ -225,6 +236,8 @@ def __init__(
                 ),
                 FutureWarning,
             )
+        if self.args.fp16 and _use_native_amp:
+            self.scaler = torch.cuda.amp.GradScaler()
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
@@ -428,7 +441,7 @@ def train(self, model_path: Optional[str] = None):
             scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
 
         model = self.model
-        if self.args.fp16:
+        if self.args.fp16 and _use_apex:
             if not is_apex_available():
                 raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
             model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
@@ -525,13 +538,20 @@ def train(self, model_path: Optional[str] = None):
                     len(epoch_iterator) <= self.args.gradient_accumulation_steps
                     and (step + 1) == len(epoch_iterator)
                 ):
-                    if self.args.fp16:
+                    if self.args.fp16 and _use_native_amp:
+                        self.scaler.unscale_(optimizer)
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+                    elif self.args.fp16 and _use_apex:
                         torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                     else:
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
 
                     if is_torch_tpu_available():
                         xm.optimizer_step(optimizer)
+
+                    if self.args.fp16 and _use_native_amp:
+                        self.scaler.step(optimizer)
+                        self.scaler.update()
                     else:
                         optimizer.step()
 
@@ -699,19 +719,27 @@ def training_step(
         model.train()
         inputs = self._prepare_inputs(inputs, model)
 
-        outputs = model(**inputs)
-        # We don't use .loss here since the model may return tuples instead of ModelOutput.
-        loss = outputs[0]
+        if self.args.fp16 and _use_native_amp:
+            with autocast():
+                outputs = model(**inputs)
+                loss = outputs[0]
+        else:
+            outputs = model(**inputs)
+            # We don't use .loss here since the model may return tuples instead of ModelOutput.
+            loss = outputs[0]
 
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
 
         if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
         if self.args.gradient_accumulation_steps > 1:
             loss = loss / self.args.gradient_accumulation_steps
 
-        if self.args.fp16:
+        if self.args.fp16 and _use_native_amp:
+            self.scaler.scale(loss).backward()
+        elif self.args.fp16 and _use_apex:
             with amp.scale_loss(loss, optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
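
For context, here is how the pieces this diff adds (autocast around the forward pass, GradScaler for the backward and optimizer step, and unscale_ before gradient clipping) fit together in one training step. This is a minimal standalone sketch assuming PyTorch >= 1.6, not code from the Trainer; the toy model, optimizer, learning rate, and clipping value of 1.0 are illustrative placeholders.

import torch
from torch.cuda.amp import GradScaler, autocast

# Native AMP only applies on a CUDA device; on CPU the scaler and autocast
# below are constructed with enabled=False and become no-ops, so the same
# code path runs everywhere.
fp16 = torch.cuda.is_available()
device = "cuda" if fp16 else "cpu"

model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=fp16)

inputs = torch.randn(8, 10, device=device)
labels = torch.randint(0, 2, (8,), device=device)

with autocast(enabled=fp16):  # forward pass runs in mixed precision
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.unscale_(optimizer)     # unscale before clipping, as the diff does
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)         # skips the step if gradients contain inf/nan
scaler.update()                # adjusts the loss scale for the next iteration
optimizer.zero_grad()

The unscale_ call before clip_grad_norm_ mirrors the ordering the diff adds to train(): clipping must see the true gradient magnitudes rather than the scaled ones, and scaler.step() then knows the gradients have already been unscaled for this step.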