From 7b83feb50a8965e9d8f13b6c4042239710b97c76 Mon Sep 17 00:00:00 2001 From: "Manuel R. Ciosici" Date: Thu, 13 Jan 2022 11:14:51 -0500 Subject: [PATCH] Deprecates AdamW and adds `--optim` (#14744) * Add AdamW deprecation warning * Add --optim to Trainer * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py * fix style * fix * Regroup adamws together Co-authored-by: Stas Bekman * Change --adafactor to --optim adafactor * Use Enum for optimizer values * fixup! Change --adafactor to --optim adafactor * fixup! Change --adafactor to --optim adafactor * fixup! Change --adafactor to --optim adafactor * fixup! 
Use Enum for optimizer values * Improved documentation for --adafactor Co-authored-by: Stas Bekman * Add mention of no_deprecation_warning Co-authored-by: Stas Bekman * Rename OptimizerOptions to OptimizerNames * Use choices for --optim * Move optimizer selection code to a function and add a unit test * Change optimizer names * Rename method Co-authored-by: Stas Bekman * Rename method Co-authored-by: Stas Bekman * Remove TODO comment Co-authored-by: Stas Bekman * Rename variable Co-authored-by: Stas Bekman * Rename variable Co-authored-by: Stas Bekman * Rename function * Rename variable * Parameterize the tests for supported optimizers * Refactor * Attempt to make tests pass on CircleCI * Add a test with apex * rework to add apex to parameterized; add actual train test * fix import when torch is not available * fix optim_test_params when torch is not available * fix optim_test_params when torch is not available * re-org * small re-org * fix test_fused_adam_no_apex * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Remove .value from OptimizerNames * Rename optimizer strings s|--adam_|--adamw_| * Also rename Enum options * small fix * Fix instantiation of OptimizerNames. 
Remove redundant test * Use ExplicitEnum instead of Enum * Add unit test with string optimizer * Change optimizer default to string value Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Stas Bekman Co-authored-by: Stas Bekman --- src/transformers/optimization.py | 10 +++ src/transformers/trainer.py | 58 +++++++++++++---- src/transformers/training_args.py | 29 ++++++++- tests/test_trainer.py | 101 +++++++++++++++++++++++++++++- 4 files changed, 183 insertions(+), 15 deletions(-) diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 124ce7f086..269e767e93 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -15,6 +15,7 @@ """PyTorch optimization for BERT model.""" import math +import warnings from typing import Callable, Iterable, Optional, Tuple, Union import torch @@ -287,6 +288,8 @@ class AdamW(Optimizer): Decoupled weight decay to apply. correct_bias (`bool`, *optional*, defaults to `True`): Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). + no_deprecation_warning (`bool`, *optional*, defaults to `False`): + A flag used to disable the deprecation warning (set to `True` to disable the warning). """ def __init__( @@ -297,7 +300,14 @@ class AdamW(Optimizer): eps: float = 1e-6, weight_decay: float = 0.0, correct_bias: bool = True, + no_deprecation_warning: bool = False, ): + if not no_deprecation_warning: + warnings.warn( + "This implementation of AdamW is deprecated and will be removed in a future version. 
Use the" + " PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning", + FutureWarning, + ) require_version("torch>=1.5.0") # add_ with alpha if lr < 0.0: raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1b382c5171..a44331ae23 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -77,7 +77,7 @@ from .file_utils import ( from .modelcard import TrainingSummary from .modeling_utils import PreTrainedModel, unwrap_model from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES -from .optimization import Adafactor, AdamW, get_scheduler +from .optimization import Adafactor, get_scheduler from .tokenization_utils_base import PreTrainedTokenizerBase from .trainer_callback import ( CallbackHandler, @@ -128,7 +128,7 @@ from .trainer_utils import ( set_seed, speed_metrics, ) -from .training_args import ParallelMode, TrainingArguments +from .training_args import OptimizerNames, ParallelMode, TrainingArguments from .utils import logging @@ -819,17 +819,9 @@ class Trainer: "weight_decay": 0.0, }, ] - optimizer_cls = Adafactor if self.args.adafactor else AdamW - if self.args.adafactor: - optimizer_cls = Adafactor - optimizer_kwargs = {"scale_parameter": False, "relative_step": False} - else: - optimizer_cls = AdamW - optimizer_kwargs = { - "betas": (self.args.adam_beta1, self.args.adam_beta2), - "eps": self.args.adam_epsilon, - } - optimizer_kwargs["lr"] = self.args.learning_rate + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + if self.sharded_ddp == ShardedDDPOption.SIMPLE: self.optimizer = OSS( params=optimizer_grouped_parameters, @@ -844,6 +836,46 @@ class Trainer: return self.optimizer + @staticmethod + def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: + """ + Returns the optimizer class and optimizer parameters 
based on the training arguments. + + Args: + args (`transformers.training_args.TrainingArguments`): + The training arguments for the training session. + + """ + optimizer_kwargs = {"lr": args.learning_rate} + adam_kwargs = { + "betas": (args.adam_beta1, args.adam_beta2), + "eps": args.adam_epsilon, + } + if args.optim == OptimizerNames.ADAFACTOR: + optimizer_cls = Adafactor + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) + elif args.optim == OptimizerNames.ADAMW_HF: + from .optimization import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + elif args.optim == OptimizerNames.ADAMW_TORCH: + from torch.optim import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + elif args.optim == OptimizerNames.ADAMW_APEX_FUSED: + try: + from apex.optimizers import FusedAdam + + optimizer_cls = FusedAdam + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") + else: + raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") + return optimizer_cls, optimizer_kwargs + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): """ Setup the scheduler. 
The optimizer of the trainer must have been set up either before this method is called or diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 74eaef4dcb..1afe93c372 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -24,6 +24,7 @@ from typing import Any, Dict, List, Optional from .debug_utils import DebugOption from .file_utils import ( + ExplicitEnum, cached_property, get_full_repo_name, is_sagemaker_dp_enabled, @@ -69,6 +70,17 @@ def default_logdir() -> str: return os.path.join("runs", current_time + "_" + socket.gethostname()) +class OptimizerNames(ExplicitEnum): + """ + Stores the acceptable string identifiers for optimizers. + """ + + ADAMW_HF = "adamw_hf" + ADAMW_TORCH = "adamw_torch" + ADAMW_APEX_FUSED = "adamw_apex_fused" + ADAFACTOR = "adafactor" + + @dataclass class TrainingArguments: """ @@ -327,8 +339,10 @@ class TrainingArguments: - `"tpu_metrics_debug"`: print debug metrics on TPU The options should be separated by whitespaces. + optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`): + The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor. adafactor (`bool`, *optional*, defaults to `False`): - Whether or not to use the [`Adafactor`] optimizer instead of [`AdamW`]. + This argument is deprecated. Use `--optim adafactor` instead. group_by_length (`bool`, *optional*, defaults to `False`): Whether or not to group together samples of roughly the same length in the training dataset (to minimize padding applied and be more efficient). Only useful if applying dynamic padding. 
@@ -641,6 +655,10 @@ class TrainingArguments: label_smoothing_factor: float = field( default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} ) + optim: OptimizerNames = field( + default="adamw_hf", + metadata={"help": "The optimizer to use."}, + ) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) group_by_length: bool = field( default=False, @@ -809,6 +827,15 @@ class TrainingArguments: ) if not (self.sharded_ddp == "" or not self.sharded_ddp): raise ValueError("sharded_ddp is not supported with bf16") + + self.optim = OptimizerNames(self.optim) + if self.adafactor: + warnings.warn( + "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim adafactor` instead", + FutureWarning, + ) + self.optim = OptimizerNames.ADAFACTOR + if ( is_torch_available() and self.device.type != "cuda" diff --git a/tests/test_trainer.py b/tests/test_trainer.py index e16aefdc8c..f7cb287265 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -23,10 +23,12 @@ import subprocess import tempfile import unittest from pathlib import Path +from unittest.mock import Mock, patch import numpy as np from huggingface_hub import Repository, delete_repo, login +from parameterized import parameterized from requests.exceptions import HTTPError from transformers import ( AutoTokenizer, @@ -36,7 +38,7 @@ from transformers import ( is_torch_available, logging, ) -from transformers.file_utils import WEIGHTS_NAME +from transformers.file_utils import WEIGHTS_NAME, is_apex_available from transformers.testing_utils import ( ENDPOINT_STAGING, PASS, @@ -61,6 +63,7 @@ from transformers.testing_utils import ( slow, ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from transformers.training_args import OptimizerNames from transformers.utils.hp_naming import TrialShortNamer @@ -69,6 +72,7 @@ if is_torch_available(): from torch import nn from 
torch.utils.data import IterableDataset + import transformers.optimization from transformers import ( AutoModelForSequenceClassification, EarlyStoppingCallback, @@ -1711,3 +1715,98 @@ class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase): trainer.hyperparameter_search( direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="sigopt", n_trials=4 ) + + +optim_test_params = [] +if is_torch_available(): + default_adam_kwargs = { + "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2), + "eps": TrainingArguments.adam_epsilon, + "lr": TrainingArguments.learning_rate, + } + + optim_test_params = [ + ( + OptimizerNames.ADAMW_HF, + transformers.optimization.AdamW, + default_adam_kwargs, + ), + ( + OptimizerNames.ADAMW_HF.value, + transformers.optimization.AdamW, + default_adam_kwargs, + ), + ( + OptimizerNames.ADAMW_TORCH, + torch.optim.AdamW, + default_adam_kwargs, + ), + ( + OptimizerNames.ADAFACTOR, + transformers.optimization.Adafactor, + { + "scale_parameter": False, + "relative_step": False, + "lr": TrainingArguments.learning_rate, + }, + ), + ] + if is_apex_available(): + import apex + + optim_test_params.append( + ( + OptimizerNames.ADAMW_APEX_FUSED, + apex.optimizers.FusedAdam, + default_adam_kwargs, + ) + ) + + +@require_torch +class TrainerOptimizerChoiceTest(unittest.TestCase): + def check_optim_and_kwargs(self, optim: OptimizerNames, mandatory_kwargs, expected_cls): + args = TrainingArguments(optim=optim, output_dir="None") + actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(args) + self.assertEqual(expected_cls, actual_cls) + self.assertIsNotNone(optim_kwargs) + + for p, v in mandatory_kwargs.items(): + self.assertTrue(p in optim_kwargs) + actual_v = optim_kwargs[p] + self.assertTrue(actual_v == v, f"Failed check for {p}. 
Expected {v}, but got {actual_v}.") + + @parameterized.expand(optim_test_params, skip_on_empty=True) + def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs): + # exercises all the valid --optim options + self.check_optim_and_kwargs(name, mandatory_kwargs, expected_cls) + + trainer = get_regression_trainer(optim=name) + trainer.train() + + def test_fused_adam(self): + # Pretend that apex is installed and mock apex.optimizers.FusedAdam exists. + # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the + # class itself, so mocking apex.optimizers.FusedAdam should be fine for testing and allow + # the test to run without requiring an apex installation. + mock = Mock() + modules = { + "apex": mock, + "apex.optimizers": mock.optimizers, + "apex.optimizers.FusedAdam": mock.optimizers.FusedAdam, + } + with patch.dict("sys.modules", modules): + self.check_optim_and_kwargs( + OptimizerNames.ADAMW_APEX_FUSED, + default_adam_kwargs, + mock.optimizers.FusedAdam, + ) + + def test_fused_adam_no_apex(self): + args = TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None") + + # Pretend that apex does not exist, even if installed. By setting apex to None, importing + # apex will fail even if apex is installed. + with patch.dict("sys.modules", {"apex.optimizers": None}): + with self.assertRaises(ValueError): + Trainer.get_optimizer_cls_and_kwargs(args)