Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Pytorch Native AMP support in Trainer #6151

Merged
merged 7 commits into from
Jul 31, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tqdm.auto import tqdm, trange

from .data.data_collator import DataCollator, default_data_collator
from .file_utils import is_apex_available, is_torch_tpu_available
from .file_utils import is_torch_tpu_available
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .trainer_utils import (
Expand All @@ -33,8 +33,19 @@
from .training_args import TrainingArguments


if is_apex_available():
from apex import amp
_use_native_amp = False
_use_apex = False

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
from transformers.file_utils import is_apex_available

if is_apex_available():
from apex import amp
_use_apex = True
else:
_use_native_amp = True
from torch.cuda.amp import autocast


if is_torch_tpu_available():
Expand Down Expand Up @@ -225,6 +236,8 @@ def __init__(
),
FutureWarning,
)
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()

def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
Expand Down Expand Up @@ -428,7 +441,7 @@ def train(self, model_path: Optional[str] = None):
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

model = self.model
if self.args.fp16:
if self.args.fp16 and _use_apex:
if not is_apex_available():
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
Expand Down Expand Up @@ -525,13 +538,20 @@ def train(self, model_path: Optional[str] = None):
len(epoch_iterator) <= self.args.gradient_accumulation_steps
and (step + 1) == len(epoch_iterator)
):
if self.args.fp16:
if self.args.fp16 and _use_native_amp:
self.scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
elif self.args.fp16 and _use_apex:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

if is_torch_tpu_available():
xm.optimizer_step(optimizer)

if self.args.fp16 and _use_native_amp:
self.scaler.step(optimizer)
self.scaler.update()
else:
optimizer.step()

Expand Down Expand Up @@ -699,19 +719,27 @@ def training_step(
model.train()
inputs = self._prepare_inputs(inputs, model)

outputs = model(**inputs)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs[0]
if self.args.fp16 and _use_native_amp:
with autocast():
outputs = model(**inputs)
loss = outputs[0]
else:
outputs = model(**inputs)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs[0]

if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]

if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps

if self.args.fp16:
if self.args.fp16 and _use_native_amp:
self.scaler.scale(loss).backward()
elif self.args.fp16 and _use_apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
Expand Down