Add AnyPrecisionAdamW optimizer (#18961)
* Add AnyPrecisionAdamW optimizer * Add optim_args argument to TrainingArgs * Add tests for AnyPrecisionOptimizer * Change AnyPrecisionAdam default params to float32 * Move default_anyprecision_kwargs in trainer test * Rename AnyPrecisionAdamW
This commit is contained in:
@@ -71,7 +71,13 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
|
||||
from transformers.utils import (
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_torchdistx_available,
|
||||
)
|
||||
from transformers.utils.hp_naming import TrialShortNamer
|
||||
|
||||
|
||||
@@ -2287,24 +2293,31 @@ if is_torch_available():
|
||||
"lr": TrainingArguments.learning_rate,
|
||||
}
|
||||
|
||||
default_anyprecision_kwargs = {
|
||||
"use_kahan_summation": False,
|
||||
"momentum_dtype": torch.float32,
|
||||
"variance_dtype": torch.float32,
|
||||
"compensation_buffer_dtype": torch.bfloat16,
|
||||
}
|
||||
|
||||
optim_test_params = [
|
||||
(
|
||||
OptimizerNames.ADAMW_HF,
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_HF, output_dir="None"),
|
||||
transformers.optimization.AdamW,
|
||||
default_adam_kwargs,
|
||||
),
|
||||
(
|
||||
OptimizerNames.ADAMW_HF.value,
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_HF.value, output_dir="None"),
|
||||
transformers.optimization.AdamW,
|
||||
default_adam_kwargs,
|
||||
),
|
||||
(
|
||||
OptimizerNames.ADAMW_TORCH,
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_TORCH, output_dir="None"),
|
||||
torch.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
),
|
||||
(
|
||||
OptimizerNames.ADAFACTOR,
|
||||
TrainingArguments(optim=OptimizerNames.ADAFACTOR, output_dir="None"),
|
||||
transformers.optimization.Adafactor,
|
||||
{
|
||||
"scale_parameter": False,
|
||||
@@ -2319,7 +2332,7 @@ if is_torch_available():
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
OptimizerNames.ADAMW_APEX_FUSED,
|
||||
TrainingArguments(OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"),
|
||||
apex.optimizers.FusedAdam,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
@@ -2330,32 +2343,42 @@ if is_torch_available():
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
OptimizerNames.ADAMW_BNB,
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, ouput_dir="None"),
|
||||
bnb.optim.Adam8bit,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
if is_torchdistx_available():
|
||||
import torchdistx
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None"),
|
||||
torchdistx.optimizers.AnyPrecisionAdamW,
|
||||
dict(default_adam_kwargs, **default_anyprecision_kwargs),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
def check_optim_and_kwargs(self, optim: OptimizerNames, mandatory_kwargs, expected_cls):
|
||||
args = TrainingArguments(optim=optim, output_dir="None")
|
||||
actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
def check_optim_and_kwargs(self, training_args: TrainingArguments, expected_cls, expected_kwargs):
|
||||
actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||
self.assertEqual(expected_cls, actual_cls)
|
||||
self.assertIsNotNone(optim_kwargs)
|
||||
|
||||
for p, v in mandatory_kwargs.items():
|
||||
for p, v in expected_kwargs.items():
|
||||
self.assertTrue(p in optim_kwargs)
|
||||
actual_v = optim_kwargs[p]
|
||||
self.assertTrue(actual_v == v, f"Failed check for {p}. Expected {v}, but got {actual_v}.")
|
||||
|
||||
@parameterized.expand(optim_test_params, skip_on_empty=True)
|
||||
def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):
|
||||
def test_optim_supported(self, training_args: TrainingArguments, expected_cls, expected_kwargs):
|
||||
# exercises all the valid --optim options
|
||||
self.check_optim_and_kwargs(name, mandatory_kwargs, expected_cls)
|
||||
self.check_optim_and_kwargs(training_args, expected_cls, expected_kwargs)
|
||||
|
||||
trainer = get_regression_trainer(optim=name)
|
||||
trainer = get_regression_trainer(**training_args.to_dict())
|
||||
trainer.train()
|
||||
|
||||
def test_fused_adam(self):
|
||||
@@ -2371,9 +2394,9 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
OptimizerNames.ADAMW_APEX_FUSED,
|
||||
default_adam_kwargs,
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"),
|
||||
mock.optimizers.FusedAdam,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
|
||||
def test_fused_adam_no_apex(self):
|
||||
@@ -2398,9 +2421,9 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
OptimizerNames.ADAMW_BNB,
|
||||
default_adam_kwargs,
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
|
||||
mock.optim.Adam8bit,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
|
||||
def test_bnb_adam8bit_no_bnb(self):
|
||||
@@ -2412,6 +2435,33 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
def test_anyprecision_adamw(self):
|
||||
# Pretend that torchdistx is installed and mock torchdistx.optimizers.AnyPrecisionAdamW exists.
|
||||
# Trainer.get_optimizer_cls_and_kwargs does not use AnyPrecisioinAdamW. It only has to return the
|
||||
# class given, so mocking torchdistx.optimizers.AnyPrecisionAdamW should be fine for testing and allow
|
||||
# the test to run without requiring a bnb installation.
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"torchdistx": mock,
|
||||
"torchdistx.optimizers": mock.optimizers,
|
||||
"torchdistx.optimizers.AnyPrecisionAdamW.": mock.optimizers.AnyPrecisionAdamW,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None"),
|
||||
mock.optimizers.AnyPrecisionAdamW,
|
||||
dict(default_adam_kwargs, **default_anyprecision_kwargs),
|
||||
)
|
||||
|
||||
def test_no_torchdistx_anyprecision_adamw(self):
|
||||
args = TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None")
|
||||
|
||||
# Pretend that torchdistx does not exist, even if installed. By setting torchdistx to None, importing
|
||||
# torchdistx.optimizers will fail even if torchdistx is installed.
|
||||
with patch.dict("sys.modules", {"torchdistx.optimizers": None}):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_wandb
|
||||
|
||||
Reference in New Issue
Block a user