From 84c9cc6d1599e1a64ee73e14ce33727ec865baef Mon Sep 17 00:00:00 2001 From: atturaioe <76523524+atturaioe@users.noreply.github.com> Date: Fri, 18 Nov 2022 16:27:08 +0200 Subject: [PATCH] Add AnyPrecisionAdamW optimizer (#18961) * Add AnyPrecisionAdamW optimizer * Add optim_args argument to TrainingArgs * Add tests for AnyPrecisionOptimizer * Change AnyPrecisionAdam default params to float32 * Move default_anyprecision_kwargs in trainer test * Rename AnyPrecisionAdamW --- src/transformers/trainer.py | 30 +++++++++ src/transformers/training_args.py | 6 +- src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 4 ++ tests/trainer/test_trainer.py | 86 ++++++++++++++++++++------ 5 files changed, 108 insertions(+), 19 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 77e8935da2..46e7b78f4a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -29,6 +29,7 @@ import sys import time import warnings from collections.abc import Mapping +from distutils.util import strtobool from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -1081,7 +1082,16 @@ class Trainer: The training arguments for the training session. 
""" + + # parse args.optim_args + optim_args = {} + if args.optim_args: + for mapping in args.optim_args.replace(" ", "").split(","): + key, value = mapping.split("=") + optim_args[key] = value + optimizer_kwargs = {"lr": args.learning_rate} + adam_kwargs = { "betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon, @@ -1123,6 +1133,26 @@ class Trainer: optimizer_kwargs.update(adam_kwargs) except ImportError: raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!") + elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: + try: + from torchdistx.optimizers import AnyPrecisionAdamW + + optimizer_cls = AnyPrecisionAdamW + optimizer_kwargs.update(adam_kwargs) + + # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx. + optimizer_kwargs.update( + { + "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")), + "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")), + "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")), + "compensation_buffer_dtype": getattr( + torch, optim_args.get("compensation_buffer_dtype", "bfloat16") + ), + } + ) + except ImportError: + raise ValueError("Please install https://github.com/pytorch/torchdistx") elif args.optim == OptimizerNames.SGD: optimizer_cls = torch.optim.SGD elif args.optim == OptimizerNames.ADAGRAD: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 0c3af0ae6f..60dc404d2a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -113,6 +113,7 @@ class OptimizerNames(ExplicitEnum): ADAMW_APEX_FUSED = "adamw_apex_fused" ADAFACTOR = "adafactor" ADAMW_BNB = "adamw_bnb_8bit" + ADAMW_ANYPRECISION = "adamw_anyprecision" SGD = "sgd" ADAGRAD = "adagrad" @@ -401,7 +402,9 @@ class TrainingArguments: The options should be separated by whitespaces. 
optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`): - The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor. + The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, adamw_anyprecision or adafactor. + optim_args (`str`, *optional*): + Optional arguments that are supplied to AnyPrecisionAdamW. adafactor (`bool`, *optional*, defaults to `False`): This argument is deprecated. Use `--optim adafactor` instead. group_by_length (`bool`, *optional*, defaults to `False`): @@ -857,6 +860,7 @@ class TrainingArguments: default="adamw_hf", metadata={"help": "The optimizer to use."}, ) + optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."}) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) group_by_length: bool = field( default=False, diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 7701145bf6..8e2d62a04c 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -153,6 +153,7 @@ from .import_utils import ( is_torch_tf32_available, is_torch_tpu_available, is_torchaudio_available, + is_torchdistx_available, is_torchdynamo_available, is_training_run_on_sagemaker, is_vision_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 6456fa4166..474b204170 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -508,6 +508,10 @@ def is_bitsandbytes_available(): return importlib.util.find_spec("bitsandbytes") is not None +def is_torchdistx_available(): + return importlib.util.find_spec("torchdistx") is not None + + def is_faiss_available(): return _faiss_available diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a8f4c11dcc..19016640c9 100644 --- a/tests/trainer/test_trainer.py +++ 
b/tests/trainer/test_trainer.py @@ -71,7 +71,13 @@ from transformers.testing_utils import ( ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.training_args import OptimizerNames -from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available +from transformers.utils import ( + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + is_apex_available, + is_bitsandbytes_available, + is_torchdistx_available, +) from transformers.utils.hp_naming import TrialShortNamer @@ -2287,24 +2293,31 @@ if is_torch_available(): "lr": TrainingArguments.learning_rate, } + default_anyprecision_kwargs = { + "use_kahan_summation": False, + "momentum_dtype": torch.float32, + "variance_dtype": torch.float32, + "compensation_buffer_dtype": torch.bfloat16, + } + optim_test_params = [ ( - OptimizerNames.ADAMW_HF, + TrainingArguments(optim=OptimizerNames.ADAMW_HF, output_dir="None"), transformers.optimization.AdamW, default_adam_kwargs, ), ( - OptimizerNames.ADAMW_HF.value, + TrainingArguments(optim=OptimizerNames.ADAMW_HF.value, output_dir="None"), transformers.optimization.AdamW, default_adam_kwargs, ), ( - OptimizerNames.ADAMW_TORCH, + TrainingArguments(optim=OptimizerNames.ADAMW_TORCH, output_dir="None"), torch.optim.AdamW, default_adam_kwargs, ), ( - OptimizerNames.ADAFACTOR, + TrainingArguments(optim=OptimizerNames.ADAFACTOR, output_dir="None"), transformers.optimization.Adafactor, { "scale_parameter": False, @@ -2319,7 +2332,7 @@ if is_torch_available(): optim_test_params.append( ( - OptimizerNames.ADAMW_APEX_FUSED, + TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"), apex.optimizers.FusedAdam, default_adam_kwargs, ) @@ -2330,32 +2343,42 @@ if is_torch_available(): optim_test_params.append( ( - OptimizerNames.ADAMW_BNB, + TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"), bnb.optim.Adam8bit, default_adam_kwargs, ) ) + if is_torchdistx_available(): + import torchdistx + + 
optim_test_params.append( + ( + TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None"), + torchdistx.optimizers.AnyPrecisionAdamW, + dict(default_adam_kwargs, **default_anyprecision_kwargs), + ) + ) + @require_torch class TrainerOptimizerChoiceTest(unittest.TestCase): - def check_optim_and_kwargs(self, optim: OptimizerNames, mandatory_kwargs, expected_cls): - args = TrainingArguments(optim=optim, output_dir="None") - actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(args) + def check_optim_and_kwargs(self, training_args: TrainingArguments, expected_cls, expected_kwargs): + actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) self.assertEqual(expected_cls, actual_cls) self.assertIsNotNone(optim_kwargs) - for p, v in mandatory_kwargs.items(): + for p, v in expected_kwargs.items(): self.assertTrue(p in optim_kwargs) actual_v = optim_kwargs[p] self.assertTrue(actual_v == v, f"Failed check for {p}. Expected {v}, but got {actual_v}.") @parameterized.expand(optim_test_params, skip_on_empty=True) - def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs): + def test_optim_supported(self, training_args: TrainingArguments, expected_cls, expected_kwargs): # exercises all the valid --optim options - self.check_optim_and_kwargs(name, mandatory_kwargs, expected_cls) + self.check_optim_and_kwargs(training_args, expected_cls, expected_kwargs) - trainer = get_regression_trainer(optim=name) + trainer = get_regression_trainer(**training_args.to_dict()) trainer.train() def test_fused_adam(self): @@ -2371,9 +2394,9 @@ class TrainerOptimizerChoiceTest(unittest.TestCase): } with patch.dict("sys.modules", modules): self.check_optim_and_kwargs( - OptimizerNames.ADAMW_APEX_FUSED, - default_adam_kwargs, + TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"), mock.optimizers.FusedAdam, + default_adam_kwargs, ) def test_fused_adam_no_apex(self): @@ -2398,9 +2421,9 @@ class 
TrainerOptimizerChoiceTest(unittest.TestCase): } with patch.dict("sys.modules", modules): self.check_optim_and_kwargs( - OptimizerNames.ADAMW_BNB, - default_adam_kwargs, + TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"), mock.optim.Adam8bit, + default_adam_kwargs, ) def test_bnb_adam8bit_no_bnb(self): @@ -2412,6 +2435,33 @@ class TrainerOptimizerChoiceTest(unittest.TestCase): with self.assertRaises(ValueError): Trainer.get_optimizer_cls_and_kwargs(args) + def test_anyprecision_adamw(self): + # Pretend that torchdistx is installed and mock torchdistx.optimizers.AnyPrecisionAdamW exists. + # Trainer.get_optimizer_cls_and_kwargs does not use AnyPrecisionAdamW. It only has to return the + # class given, so mocking torchdistx.optimizers.AnyPrecisionAdamW should be fine for testing and allow + # the test to run without requiring a torchdistx installation. + mock = Mock() + modules = { + "torchdistx": mock, + "torchdistx.optimizers": mock.optimizers, + "torchdistx.optimizers.AnyPrecisionAdamW.": mock.optimizers.AnyPrecisionAdamW, + } + with patch.dict("sys.modules", modules): + self.check_optim_and_kwargs( + TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None"), + mock.optimizers.AnyPrecisionAdamW, + dict(default_adam_kwargs, **default_anyprecision_kwargs), + ) + + def test_no_torchdistx_anyprecision_adamw(self): + args = TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None") + + # Pretend that torchdistx does not exist, even if installed. By setting torchdistx to None, importing + # torchdistx.optimizers will fail even if torchdistx is installed. + with patch.dict("sys.modules", {"torchdistx.optimizers": None}): + with self.assertRaises(ValueError): + Trainer.get_optimizer_cls_and_kwargs(args) + @require_torch @require_wandb