Add possibility to switch between APEX and AMP in Trainer (#9137)
* Add possibility to switch between APEX and AMP in Trainer

* Update src/transformers/training_args.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments

* Update src/transformers/training_args.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -93,8 +93,7 @@ from .training_args import TrainingArguments
|
|||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
_use_native_amp = False
|
_is_native_amp_available = False
|
||||||
_use_apex = False
|
|
||||||
|
|
||||||
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
||||||
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
|
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
|
||||||
@@ -110,16 +109,10 @@ if version.parse(torch.__version__) < version.parse("1.6"):
|
|||||||
|
|
||||||
if is_apex_available():
|
if is_apex_available():
|
||||||
from apex import amp
|
from apex import amp
|
||||||
_use_apex = True
|
|
||||||
else:
|
else:
|
||||||
_use_native_amp = True
|
_is_native_amp_available = True
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
if version.parse(torch.__version__) < version.parse("1.2"):
|
|
||||||
_use_ddp_no_sync = False
|
|
||||||
else:
|
|
||||||
_use_ddp_no_sync = True
|
|
||||||
|
|
||||||
if is_datasets_available():
|
if is_datasets_available():
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
@@ -292,13 +285,30 @@ class Trainer:
|
|||||||
if isinstance(eval_dataset, datasets.Dataset):
|
if isinstance(eval_dataset, datasets.Dataset):
|
||||||
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
||||||
|
|
||||||
|
# Mixed precision setup
|
||||||
|
self.use_apex = False
|
||||||
|
self.use_amp = False
|
||||||
|
if args.fp16:
|
||||||
|
if args.fp16_backend == "auto":
|
||||||
|
backend = "amp" if _is_native_amp_available else "apex"
|
||||||
|
else:
|
||||||
|
backend = args.fp16_backend
|
||||||
|
|
||||||
|
if backend == "amp":
|
||||||
|
self.use_amp = True
|
||||||
|
self.scaler = torch.cuda.amp.GradScaler()
|
||||||
|
else:
|
||||||
|
if not is_apex_available():
|
||||||
|
raise ImportError(
|
||||||
|
"Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
|
||||||
|
)
|
||||||
|
self.use_apex = True
|
||||||
|
|
||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
self.control = TrainerControl()
|
self.control = TrainerControl()
|
||||||
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
|
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
|
||||||
# state at each call to self.log.
|
# state at each call to self.log.
|
||||||
self._total_flos = None
|
self._total_flos = None
|
||||||
if self.args.fp16 and _use_native_amp:
|
|
||||||
self.scaler = torch.cuda.amp.GradScaler()
|
|
||||||
self.hp_search_backend = None
|
self.hp_search_backend = None
|
||||||
self.use_tune_checkpoints = False
|
self.use_tune_checkpoints = False
|
||||||
default_label_names = (
|
default_label_names = (
|
||||||
@@ -625,9 +635,7 @@ class Trainer:
|
|||||||
|
|
||||||
# Mixed precision training with apex (torch < 1.6)
|
# Mixed precision training with apex (torch < 1.6)
|
||||||
model = self.model
|
model = self.model
|
||||||
if self.args.fp16 and _use_apex:
|
if self.use_apex:
|
||||||
if not is_apex_available():
|
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
|
||||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||||
|
|
||||||
# Multi-gpu training (should be after apex fp16 initialization)
|
# Multi-gpu training (should be after apex fp16 initialization)
|
||||||
@@ -756,11 +764,8 @@ class Trainer:
|
|||||||
if (step + 1) % self.args.gradient_accumulation_steps == 0:
|
if (step + 1) % self.args.gradient_accumulation_steps == 0:
|
||||||
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
|
||||||
|
|
||||||
if (
|
if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1:
|
||||||
((step + 1) % self.args.gradient_accumulation_steps != 0)
|
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
|
||||||
and self.args.local_rank != -1
|
|
||||||
and _use_ddp_no_sync
|
|
||||||
):
|
|
||||||
with model.no_sync():
|
with model.no_sync():
|
||||||
tr_loss += self.training_step(model, inputs)
|
tr_loss += self.training_step(model, inputs)
|
||||||
else:
|
else:
|
||||||
@@ -772,17 +777,17 @@ class Trainer:
|
|||||||
steps_in_epoch <= self.args.gradient_accumulation_steps
|
steps_in_epoch <= self.args.gradient_accumulation_steps
|
||||||
and (step + 1) == steps_in_epoch
|
and (step + 1) == steps_in_epoch
|
||||||
):
|
):
|
||||||
if self.args.fp16 and _use_native_amp:
|
if self.use_amp:
|
||||||
self.scaler.unscale_(self.optimizer)
|
self.scaler.unscale_(self.optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
||||||
elif self.args.fp16 and _use_apex:
|
elif self.use_apex:
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
|
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
xm.optimizer_step(self.optimizer)
|
xm.optimizer_step(self.optimizer)
|
||||||
elif self.args.fp16 and _use_native_amp:
|
elif self.use_amp:
|
||||||
self.scaler.step(self.optimizer)
|
self.scaler.step(self.optimizer)
|
||||||
self.scaler.update()
|
self.scaler.update()
|
||||||
else:
|
else:
|
||||||
@@ -1089,7 +1094,7 @@ class Trainer:
|
|||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
if self.args.fp16 and _use_native_amp:
|
if self.use_amp:
|
||||||
with autocast():
|
with autocast():
|
||||||
loss = self.compute_loss(model, inputs)
|
loss = self.compute_loss(model, inputs)
|
||||||
else:
|
else:
|
||||||
@@ -1101,9 +1106,9 @@ class Trainer:
|
|||||||
if self.args.gradient_accumulation_steps > 1:
|
if self.args.gradient_accumulation_steps > 1:
|
||||||
loss = loss / self.args.gradient_accumulation_steps
|
loss = loss / self.args.gradient_accumulation_steps
|
||||||
|
|
||||||
if self.args.fp16 and _use_native_amp:
|
if self.use_amp:
|
||||||
self.scaler.scale(loss).backward()
|
self.scaler.scale(loss).backward()
|
||||||
elif self.args.fp16 and _use_apex:
|
elif self.use_apex:
|
||||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
else:
|
else:
|
||||||
@@ -1498,7 +1503,7 @@ class Trainer:
|
|||||||
ignore_keys = []
|
ignore_keys = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.args.fp16 and _use_native_amp:
|
if self.use_amp:
|
||||||
with autocast():
|
with autocast():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -211,6 +211,10 @@ class TrainingArguments:
|
|||||||
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
||||||
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
|
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
|
||||||
step can take a long time) but will not yield the same results as the interrupted training would have.
|
step can take a long time) but will not yield the same results as the interrupted training would have.
|
||||||
|
fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
|
||||||
|
The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
|
||||||
|
:obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
|
||||||
|
other choices will force the requested backend.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -378,6 +382,10 @@ class TrainingArguments:
|
|||||||
"help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
|
"help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
fp16_backend: str = field(
|
||||||
|
default="auto",
|
||||||
|
metadata={"help": "The backend to be used for mixed precision. Should be one of 'auto', 'amp' or 'apex'."},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.disable_tqdm is None:
|
if self.disable_tqdm is None:
|
||||||
|
|||||||
@@ -798,34 +798,38 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_early_stopping_callback(self):
|
def test_early_stopping_callback(self):
|
||||||
# early stopping stops training before num_training_epochs
|
# early stopping stops training before num_training_epochs
|
||||||
trainer = get_regression_trainer(
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
num_train_epochs=20,
|
trainer = get_regression_trainer(
|
||||||
gradient_accumulation_steps=1,
|
output_dir=tmp_dir,
|
||||||
per_device_train_batch_size=16,
|
num_train_epochs=20,
|
||||||
load_best_model_at_end=True,
|
gradient_accumulation_steps=1,
|
||||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
per_device_train_batch_size=16,
|
||||||
compute_metrics=AlmostAccuracy(),
|
load_best_model_at_end=True,
|
||||||
metric_for_best_model="accuracy",
|
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||||
)
|
compute_metrics=AlmostAccuracy(),
|
||||||
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
|
metric_for_best_model="accuracy",
|
||||||
train_output = trainer.train()
|
)
|
||||||
self.assertLess(train_output.global_step, 20 * 64 / 16)
|
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
|
||||||
|
train_output = trainer.train()
|
||||||
|
self.assertLess(train_output.global_step, 20 * 64 / 16)
|
||||||
|
|
||||||
# Invalid inputs to trainer with early stopping callback result in assertion error
|
# Invalid inputs to trainer with early stopping callback result in assertion error
|
||||||
trainer = get_regression_trainer(
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
num_train_epochs=20,
|
trainer = get_regression_trainer(
|
||||||
gradient_accumulation_steps=1,
|
output_dir=tmp_dir,
|
||||||
per_device_train_batch_size=16,
|
num_train_epochs=20,
|
||||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
gradient_accumulation_steps=1,
|
||||||
compute_metrics=AlmostAccuracy(),
|
per_device_train_batch_size=16,
|
||||||
metric_for_best_model="accuracy",
|
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||||
)
|
compute_metrics=AlmostAccuracy(),
|
||||||
trainer.add_callback(EarlyStoppingCallback(1))
|
metric_for_best_model="accuracy",
|
||||||
self.assertEqual(trainer.state.global_step, 0)
|
)
|
||||||
try:
|
trainer.add_callback(EarlyStoppingCallback(1))
|
||||||
trainer.train()
|
|
||||||
except AssertionError:
|
|
||||||
self.assertEqual(trainer.state.global_step, 0)
|
self.assertEqual(trainer.state.global_step, 0)
|
||||||
|
try:
|
||||||
|
trainer.train()
|
||||||
|
except AssertionError:
|
||||||
|
self.assertEqual(trainer.state.global_step, 0)
|
||||||
|
|
||||||
def test_flos_extraction(self):
|
def test_flos_extraction(self):
|
||||||
trainer = get_regression_trainer(learning_rate=0.1)
|
trainer = get_regression_trainer(learning_rate=0.1)
|
||||||
|
|||||||
Reference in New Issue
Block a user