From ad895af98d9d69bfd3bf709225fb5920d4de8b86 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 15 Dec 2020 16:38:10 -0500 Subject: [PATCH] Add possibility to switch between APEX and AMP in Trainer (#9137) * Add possibility to switch between APEX and AMP in Trainer * Update src/transformers/training_args.py Co-authored-by: Stas Bekman * Address review comments * Update src/transformers/training_args.py Co-authored-by: Stas Bekman Co-authored-by: Stas Bekman --- src/transformers/trainer.py | 57 +++++++++++++++++-------------- src/transformers/training_args.py | 8 +++++ tests/test_trainer.py | 54 +++++++++++++++-------------- 3 files changed, 68 insertions(+), 51 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 41f36917d9..102bf090ce 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -93,8 +93,7 @@ from .training_args import TrainingArguments from .utils import logging -_use_native_amp = False -_use_apex = False +_is_native_amp_available = False DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_PROGRESS_CALLBACK = ProgressCallback @@ -110,16 +109,10 @@ if version.parse(torch.__version__) < version.parse("1.6"): if is_apex_available(): from apex import amp - _use_apex = True else: - _use_native_amp = True + _is_native_amp_available = True from torch.cuda.amp import autocast -if version.parse(torch.__version__) < version.parse("1.2"): - _use_ddp_no_sync = False -else: - _use_ddp_no_sync = True - if is_datasets_available(): import datasets @@ -292,13 +285,30 @@ class Trainer: if isinstance(eval_dataset, datasets.Dataset): self._remove_unused_columns(self.eval_dataset, description="evaluation") + # Mixed precision setup + self.use_apex = False + self.use_amp = False + if args.fp16: + if args.fp16_backend == "auto": + backend = "amp" if _is_native_amp_available else "apex" + else: + backend = args.fp16_backend + + if backend == "amp": + self.use_amp = True + self.scaler = torch.cuda.amp.GradScaler() + else: + if not is_apex_available(): + raise ImportError( + "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex." + ) + self.use_apex = True + self.state = TrainerState() self.control = TrainerControl() # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the # state at each call to self.log. self._total_flos = None - if self.args.fp16 and _use_native_amp: - self.scaler = torch.cuda.amp.GradScaler() self.hp_search_backend = None self.use_tune_checkpoints = False default_label_names = ( @@ -625,9 +635,7 @@ class Trainer: # Mixed precision training with apex (torch < 1.6) model = self.model - if self.args.fp16 and _use_apex: - if not is_apex_available(): - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + if self.use_apex: model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) # Multi-gpu training (should be after apex fp16 initialization) @@ -756,11 +764,8 @@ class Trainer: if (step + 1) % self.args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control) - if ( - ((step + 1) % self.args.gradient_accumulation_steps != 0) - and self.args.local_rank != -1 - and _use_ddp_no_sync - ): + if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1: + # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. with model.no_sync(): tr_loss += self.training_step(model, inputs) else: @@ -772,17 +777,17 @@ class Trainer: steps_in_epoch <= self.args.gradient_accumulation_steps and (step + 1) == steps_in_epoch ): - if self.args.fp16 and _use_native_amp: + if self.use_amp: self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) - elif self.args.fp16 and _use_apex: + elif self.use_apex: torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) if is_torch_tpu_available(): xm.optimizer_step(self.optimizer) - elif self.args.fp16 and _use_native_amp: + elif self.use_amp: self.scaler.step(self.optimizer) self.scaler.update() else: @@ -1089,7 +1094,7 @@ class Trainer: model.train() inputs = self._prepare_inputs(inputs) - if self.args.fp16 and _use_native_amp: + if self.use_amp: with autocast(): loss = self.compute_loss(model, inputs) else: @@ -1101,9 +1106,9 @@ class Trainer: if self.args.gradient_accumulation_steps > 1: loss = loss / self.args.gradient_accumulation_steps - if self.args.fp16 and _use_native_amp: + if self.use_amp: self.scaler.scale(loss).backward() - elif self.args.fp16 and _use_apex: + elif self.use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: @@ -1498,7 +1503,7 @@ class Trainer: ignore_keys = [] with torch.no_grad(): - if self.args.fp16 and _use_native_amp: + if self.use_amp: with autocast(): outputs = model(**inputs) else: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5de1dfbc07..60d80731ed 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -211,6 +211,10 @@ class TrainingArguments: When resuming training, whether or not to skip the epochs and batches to get the data loading at the same stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping step can take a long time) but will not yield the same results as the interrupted training would have. + fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`): + The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or + :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the + other choices will force the requested backend. """ output_dir: str = field( @@ -378,6 +382,10 @@ class TrainingArguments: "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data." }, ) + fp16_backend: str = field( + default="auto", + metadata={"help": "The backend to be used for mixed precision. Should be one of 'auto', 'amp' or 'apex'."}, + ) def __post_init__(self): if self.disable_tqdm is None: diff --git a/tests/test_trainer.py b/tests/test_trainer.py index e6fd44c37c..29a57ab07f 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -798,34 +798,38 @@ class TrainerIntegrationTest(unittest.TestCase): def test_early_stopping_callback(self): # early stopping stops training before num_training_epochs - trainer = get_regression_trainer( - num_train_epochs=20, - gradient_accumulation_steps=1, - per_device_train_batch_size=16, - load_best_model_at_end=True, - evaluation_strategy=EvaluationStrategy.EPOCH, - compute_metrics=AlmostAccuracy(), - metric_for_best_model="accuracy", - ) - trainer.add_callback(EarlyStoppingCallback(1, 0.0001)) - train_output = trainer.train() - self.assertLess(train_output.global_step, 20 * 64 / 16) + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = get_regression_trainer( + output_dir=tmp_dir, + num_train_epochs=20, + gradient_accumulation_steps=1, + per_device_train_batch_size=16, + load_best_model_at_end=True, + evaluation_strategy=EvaluationStrategy.EPOCH, + compute_metrics=AlmostAccuracy(), + metric_for_best_model="accuracy", + ) + trainer.add_callback(EarlyStoppingCallback(1, 0.0001)) + train_output = trainer.train() + self.assertLess(train_output.global_step, 20 * 64 / 16) # Invalid inputs to trainer with early stopping callback result in assertion error - trainer = get_regression_trainer( - num_train_epochs=20, - gradient_accumulation_steps=1, - per_device_train_batch_size=16, - evaluation_strategy=EvaluationStrategy.EPOCH, - compute_metrics=AlmostAccuracy(), - metric_for_best_model="accuracy", - ) - trainer.add_callback(EarlyStoppingCallback(1)) - self.assertEqual(trainer.state.global_step, 0) - try: - trainer.train() - except AssertionError: + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = get_regression_trainer( + output_dir=tmp_dir, + num_train_epochs=20, + gradient_accumulation_steps=1, + per_device_train_batch_size=16, + evaluation_strategy=EvaluationStrategy.EPOCH, + compute_metrics=AlmostAccuracy(), + metric_for_best_model="accuracy", + ) + trainer.add_callback(EarlyStoppingCallback(1)) self.assertEqual(trainer.state.global_step, 0) + try: + trainer.train() + except AssertionError: + self.assertEqual(trainer.state.global_step, 0) def test_flos_extraction(self): trainer = get_regression_trainer(learning_rate=0.1)