Add AnyPrecisionAdamW optimizer (#18961)
* Add AnyPrecisionAdamW optimizer * Add optim_args argument to TrainingArgs * Add tests for AnyPrecisionOptimizer * Change AnyPrecisionAdam default params to float32 * Move default_anyprecision_kwargs in trainer test * Rename AnyPrecisionAdamW
This commit is contained in:
@@ -29,6 +29,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from distutils.util import strtobool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -1081,7 +1082,16 @@ class Trainer:
|
|||||||
The training arguments for the training session.
|
The training arguments for the training session.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# parse args.optim_args
|
||||||
|
optim_args = {}
|
||||||
|
if args.optim_args:
|
||||||
|
for mapping in args.optim_args.replace(" ", "").split(","):
|
||||||
|
key, value = mapping.split("=")
|
||||||
|
optim_args[key] = value
|
||||||
|
|
||||||
optimizer_kwargs = {"lr": args.learning_rate}
|
optimizer_kwargs = {"lr": args.learning_rate}
|
||||||
|
|
||||||
adam_kwargs = {
|
adam_kwargs = {
|
||||||
"betas": (args.adam_beta1, args.adam_beta2),
|
"betas": (args.adam_beta1, args.adam_beta2),
|
||||||
"eps": args.adam_epsilon,
|
"eps": args.adam_epsilon,
|
||||||
@@ -1123,6 +1133,26 @@ class Trainer:
|
|||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
|
raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
|
||||||
|
elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
|
||||||
|
try:
|
||||||
|
from torchdistx.optimizers import AnyPrecisionAdamW
|
||||||
|
|
||||||
|
optimizer_cls = AnyPrecisionAdamW
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
# TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
|
||||||
|
optimizer_kwargs.update(
|
||||||
|
{
|
||||||
|
"use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
|
||||||
|
"momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
|
||||||
|
"variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
|
||||||
|
"compensation_buffer_dtype": getattr(
|
||||||
|
torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError("Please install https://github.com/pytorch/torchdistx")
|
||||||
elif args.optim == OptimizerNames.SGD:
|
elif args.optim == OptimizerNames.SGD:
|
||||||
optimizer_cls = torch.optim.SGD
|
optimizer_cls = torch.optim.SGD
|
||||||
elif args.optim == OptimizerNames.ADAGRAD:
|
elif args.optim == OptimizerNames.ADAGRAD:
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ class OptimizerNames(ExplicitEnum):
|
|||||||
ADAMW_APEX_FUSED = "adamw_apex_fused"
|
ADAMW_APEX_FUSED = "adamw_apex_fused"
|
||||||
ADAFACTOR = "adafactor"
|
ADAFACTOR = "adafactor"
|
||||||
ADAMW_BNB = "adamw_bnb_8bit"
|
ADAMW_BNB = "adamw_bnb_8bit"
|
||||||
|
ADAMW_ANYPRECISION = "adamw_anyprecision"
|
||||||
SGD = "sgd"
|
SGD = "sgd"
|
||||||
ADAGRAD = "adagrad"
|
ADAGRAD = "adagrad"
|
||||||
|
|
||||||
@@ -401,7 +402,9 @@ class TrainingArguments:
|
|||||||
|
|
||||||
The options should be separated by whitespaces.
|
The options should be separated by whitespaces.
|
||||||
optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
|
optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
|
||||||
The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor.
|
The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, adamw_anyprecision or adafactor.
|
||||||
|
optim_args (`str`, *optional*):
|
||||||
|
Optional arguments that are supplied to AnyPrecisionAdamW.
|
||||||
adafactor (`bool`, *optional*, defaults to `False`):
|
adafactor (`bool`, *optional*, defaults to `False`):
|
||||||
This argument is deprecated. Use `--optim adafactor` instead.
|
This argument is deprecated. Use `--optim adafactor` instead.
|
||||||
group_by_length (`bool`, *optional*, defaults to `False`):
|
group_by_length (`bool`, *optional*, defaults to `False`):
|
||||||
@@ -857,6 +860,7 @@ class TrainingArguments:
|
|||||||
default="adamw_hf",
|
default="adamw_hf",
|
||||||
metadata={"help": "The optimizer to use."},
|
metadata={"help": "The optimizer to use."},
|
||||||
)
|
)
|
||||||
|
optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."})
|
||||||
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||||
group_by_length: bool = field(
|
group_by_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
|
|||||||
@@ -153,6 +153,7 @@ from .import_utils import (
|
|||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
|
is_torchdistx_available,
|
||||||
is_torchdynamo_available,
|
is_torchdynamo_available,
|
||||||
is_training_run_on_sagemaker,
|
is_training_run_on_sagemaker,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
|
|||||||
@@ -508,6 +508,10 @@ def is_bitsandbytes_available():
|
|||||||
return importlib.util.find_spec("bitsandbytes") is not None
|
return importlib.util.find_spec("bitsandbytes") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_torchdistx_available():
|
||||||
|
return importlib.util.find_spec("torchdistx") is not None
|
||||||
|
|
||||||
|
|
||||||
def is_faiss_available():
|
def is_faiss_available():
|
||||||
return _faiss_available
|
return _faiss_available
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,13 @@ from transformers.testing_utils import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
|
from transformers.utils import (
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
is_apex_available,
|
||||||
|
is_bitsandbytes_available,
|
||||||
|
is_torchdistx_available,
|
||||||
|
)
|
||||||
from transformers.utils.hp_naming import TrialShortNamer
|
from transformers.utils.hp_naming import TrialShortNamer
|
||||||
|
|
||||||
|
|
||||||
@@ -2287,24 +2293,31 @@ if is_torch_available():
|
|||||||
"lr": TrainingArguments.learning_rate,
|
"lr": TrainingArguments.learning_rate,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
default_anyprecision_kwargs = {
|
||||||
|
"use_kahan_summation": False,
|
||||||
|
"momentum_dtype": torch.float32,
|
||||||
|
"variance_dtype": torch.float32,
|
||||||
|
"compensation_buffer_dtype": torch.bfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
optim_test_params = [
|
optim_test_params = [
|
||||||
(
|
(
|
||||||
OptimizerNames.ADAMW_HF,
|
TrainingArguments(optim=OptimizerNames.ADAMW_HF, output_dir="None"),
|
||||||
transformers.optimization.AdamW,
|
transformers.optimization.AdamW,
|
||||||
default_adam_kwargs,
|
default_adam_kwargs,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
OptimizerNames.ADAMW_HF.value,
|
TrainingArguments(optim=OptimizerNames.ADAMW_HF.value, output_dir="None"),
|
||||||
transformers.optimization.AdamW,
|
transformers.optimization.AdamW,
|
||||||
default_adam_kwargs,
|
default_adam_kwargs,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
OptimizerNames.ADAMW_TORCH,
|
TrainingArguments(optim=OptimizerNames.ADAMW_TORCH, output_dir="None"),
|
||||||
torch.optim.AdamW,
|
torch.optim.AdamW,
|
||||||
default_adam_kwargs,
|
default_adam_kwargs,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
OptimizerNames.ADAFACTOR,
|
TrainingArguments(optim=OptimizerNames.ADAFACTOR, output_dir="None"),
|
||||||
transformers.optimization.Adafactor,
|
transformers.optimization.Adafactor,
|
||||||
{
|
{
|
||||||
"scale_parameter": False,
|
"scale_parameter": False,
|
||||||
@@ -2319,7 +2332,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
optim_test_params.append(
|
optim_test_params.append(
|
||||||
(
|
(
|
||||||
OptimizerNames.ADAMW_APEX_FUSED,
|
TrainingArguments(OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"),
|
||||||
apex.optimizers.FusedAdam,
|
apex.optimizers.FusedAdam,
|
||||||
default_adam_kwargs,
|
default_adam_kwargs,
|
||||||
)
|
)
|
||||||
@@ -2330,32 +2343,42 @@ if is_torch_available():
|
|||||||
|
|
||||||
optim_test_params.append(
|
optim_test_params.append(
|
||||||
(
|
(
|
||||||
OptimizerNames.ADAMW_BNB,
|
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, ouput_dir="None"),
|
||||||
bnb.optim.Adam8bit,
|
bnb.optim.Adam8bit,
|
||||||
default_adam_kwargs,
|
default_adam_kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_torchdistx_available():
|
||||||
|
import torchdistx
|
||||||
|
|
||||||
|
optim_test_params.append(
|
||||||
|
(
|
||||||
|
TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None"),
|
||||||
|
torchdistx.optimizers.AnyPrecisionAdamW,
|
||||||
|
dict(default_adam_kwargs, **default_anyprecision_kwargs),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TrainerOptimizerChoiceTest(unittest.TestCase):
|
class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||||
def check_optim_and_kwargs(self, optim: OptimizerNames, mandatory_kwargs, expected_cls):
|
def check_optim_and_kwargs(self, training_args: TrainingArguments, expected_cls, expected_kwargs):
|
||||||
args = TrainingArguments(optim=optim, output_dir="None")
|
actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||||
actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
|
|
||||||
self.assertEqual(expected_cls, actual_cls)
|
self.assertEqual(expected_cls, actual_cls)
|
||||||
self.assertIsNotNone(optim_kwargs)
|
self.assertIsNotNone(optim_kwargs)
|
||||||
|
|
||||||
for p, v in mandatory_kwargs.items():
|
for p, v in expected_kwargs.items():
|
||||||
self.assertTrue(p in optim_kwargs)
|
self.assertTrue(p in optim_kwargs)
|
||||||
actual_v = optim_kwargs[p]
|
actual_v = optim_kwargs[p]
|
||||||
self.assertTrue(actual_v == v, f"Failed check for {p}. Expected {v}, but got {actual_v}.")
|
self.assertTrue(actual_v == v, f"Failed check for {p}. Expected {v}, but got {actual_v}.")
|
||||||
|
|
||||||
@parameterized.expand(optim_test_params, skip_on_empty=True)
|
@parameterized.expand(optim_test_params, skip_on_empty=True)
|
||||||
def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):
|
def test_optim_supported(self, training_args: TrainingArguments, expected_cls, expected_kwargs):
|
||||||
# exercises all the valid --optim options
|
# exercises all the valid --optim options
|
||||||
self.check_optim_and_kwargs(name, mandatory_kwargs, expected_cls)
|
self.check_optim_and_kwargs(training_args, expected_cls, expected_kwargs)
|
||||||
|
|
||||||
trainer = get_regression_trainer(optim=name)
|
trainer = get_regression_trainer(**training_args.to_dict())
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
def test_fused_adam(self):
|
def test_fused_adam(self):
|
||||||
@@ -2371,9 +2394,9 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
with patch.dict("sys.modules", modules):
|
with patch.dict("sys.modules", modules):
|
||||||
self.check_optim_and_kwargs(
|
self.check_optim_and_kwargs(
|
||||||
OptimizerNames.ADAMW_APEX_FUSED,
|
TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"),
|
||||||
default_adam_kwargs,
|
|
||||||
mock.optimizers.FusedAdam,
|
mock.optimizers.FusedAdam,
|
||||||
|
default_adam_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fused_adam_no_apex(self):
|
def test_fused_adam_no_apex(self):
|
||||||
@@ -2398,9 +2421,9 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
with patch.dict("sys.modules", modules):
|
with patch.dict("sys.modules", modules):
|
||||||
self.check_optim_and_kwargs(
|
self.check_optim_and_kwargs(
|
||||||
OptimizerNames.ADAMW_BNB,
|
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
|
||||||
default_adam_kwargs,
|
|
||||||
mock.optim.Adam8bit,
|
mock.optim.Adam8bit,
|
||||||
|
default_adam_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_bnb_adam8bit_no_bnb(self):
|
def test_bnb_adam8bit_no_bnb(self):
|
||||||
@@ -2412,6 +2435,33 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||||
|
|
||||||
|
def test_anyprecision_adamw(self):
|
||||||
|
# Pretend that torchdistx is installed and mock torchdistx.optimizers.AnyPrecisionAdamW exists.
|
||||||
|
# Trainer.get_optimizer_cls_and_kwargs does not use AnyPrecisionAdamW. It only has to return the
|
||||||
|
# class given, so mocking torchdistx.optimizers.AnyPrecisionAdamW should be fine for testing and allow
|
||||||
|
# the test to run without requiring a torchdistx installation.
|
||||||
|
mock = Mock()
|
||||||
|
modules = {
|
||||||
|
"torchdistx": mock,
|
||||||
|
"torchdistx.optimizers": mock.optimizers,
|
||||||
|
"torchdistx.optimizers.AnyPrecisionAdamW.": mock.optimizers.AnyPrecisionAdamW,
|
||||||
|
}
|
||||||
|
with patch.dict("sys.modules", modules):
|
||||||
|
self.check_optim_and_kwargs(
|
||||||
|
TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None"),
|
||||||
|
mock.optimizers.AnyPrecisionAdamW,
|
||||||
|
dict(default_adam_kwargs, **default_anyprecision_kwargs),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_torchdistx_anyprecision_adamw(self):
|
||||||
|
args = TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None")
|
||||||
|
|
||||||
|
# Pretend that torchdistx does not exist, even if installed. By setting the torchdistx.optimizers entry in sys.modules to None, importing
|
||||||
|
# torchdistx.optimizers will fail even if torchdistx is installed.
|
||||||
|
with patch.dict("sys.modules", {"torchdistx.optimizers": None}):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_wandb
|
@require_wandb
|
||||||
|
|||||||
Reference in New Issue
Block a user