Add possibility to switch between APEX and AMP in Trainer (#9137)
* Add possibility to switch between APEX and AMP in Trainer
* Update src/transformers/training_args.py
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
* Address review comments
* Update src/transformers/training_args.py
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -93,8 +93,7 @@ from .training_args import TrainingArguments
|
||||
from .utils import logging
|
||||
|
||||
|
||||
_use_native_amp = False
|
||||
_use_apex = False
|
||||
_is_native_amp_available = False
|
||||
|
||||
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
||||
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
|
||||
@@ -110,16 +109,10 @@ if version.parse(torch.__version__) < version.parse("1.6"):
|
||||
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
_use_apex = True
|
||||
else:
|
||||
_use_native_amp = True
|
||||
_is_native_amp_available = True
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.2"):
|
||||
_use_ddp_no_sync = False
|
||||
else:
|
||||
_use_ddp_no_sync = True
|
||||
|
||||
if is_datasets_available():
|
||||
import datasets
|
||||
|
||||
@@ -292,13 +285,30 @@ class Trainer:
|
||||
if isinstance(eval_dataset, datasets.Dataset):
|
||||
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
||||
|
||||
# Mixed precision setup
|
||||
self.use_apex = False
|
||||
self.use_amp = False
|
||||
if args.fp16:
|
||||
if args.fp16_backend == "auto":
|
||||
backend = "amp" if _is_native_amp_available else "apex"
|
||||
else:
|
||||
backend = args.fp16_backend
|
||||
|
||||
if backend == "amp":
|
||||
self.use_amp = True
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
else:
|
||||
if not is_apex_available():
|
||||
raise ImportError(
|
||||
"Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
|
||||
)
|
||||
self.use_apex = True
|
||||
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
|
||||
# state at each call to self.log.
|
||||
self._total_flos = None
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
self.hp_search_backend = None
|
||||
self.use_tune_checkpoints = False
|
||||
default_label_names = (
|
||||
@@ -625,9 +635,7 @@ class Trainer:
|
||||
|
||||
# Mixed precision training with apex (torch < 1.6)
|
||||
model = self.model
|
||||
if self.args.fp16 and _use_apex:
|
||||
if not is_apex_available():
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
if self.use_apex:
|
||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||
|
||||
# Multi-gpu training (should be after apex fp16 initialization)
|
||||
@@ -756,11 +764,8 @@ class Trainer:
|
||||
if (step + 1) % self.args.gradient_accumulation_steps == 0:
|
||||
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
|
||||
|
||||
if (
|
||||
((step + 1) % self.args.gradient_accumulation_steps != 0)
|
||||
and self.args.local_rank != -1
|
||||
and _use_ddp_no_sync
|
||||
):
|
||||
if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1:
|
||||
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
|
||||
with model.no_sync():
|
||||
tr_loss += self.training_step(model, inputs)
|
||||
else:
|
||||
@@ -772,17 +777,17 @@ class Trainer:
|
||||
steps_in_epoch <= self.args.gradient_accumulation_steps
|
||||
and (step + 1) == steps_in_epoch
|
||||
):
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
if self.use_amp:
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
||||
elif self.args.fp16 and _use_apex:
|
||||
elif self.use_apex:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.optimizer_step(self.optimizer)
|
||||
elif self.args.fp16 and _use_native_amp:
|
||||
elif self.use_amp:
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
@@ -1089,7 +1094,7 @@ class Trainer:
|
||||
model.train()
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
if self.use_amp:
|
||||
with autocast():
|
||||
loss = self.compute_loss(model, inputs)
|
||||
else:
|
||||
@@ -1101,9 +1106,9 @@ class Trainer:
|
||||
if self.args.gradient_accumulation_steps > 1:
|
||||
loss = loss / self.args.gradient_accumulation_steps
|
||||
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
if self.use_amp:
|
||||
self.scaler.scale(loss).backward()
|
||||
elif self.args.fp16 and _use_apex:
|
||||
elif self.use_apex:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
@@ -1498,7 +1503,7 @@ class Trainer:
|
||||
ignore_keys = []
|
||||
|
||||
with torch.no_grad():
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
if self.use_amp:
|
||||
with autocast():
|
||||
outputs = model(**inputs)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user