Deprecates AdamW and adds --optim (#14744)
* Add AdamW deprecation warning * Add --optim to Trainer * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py * fix style * fix * Regroup adamws together Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Change --adafactor to --optim adafactor * Use Enum for optimizer values * fixup! Change --adafactor to --optim adafactor * fixup! Change --adafactor to --optim adafactor * fixup! Change --adafactor to --optim adafactor * fixup! Use Enum for optimizer values * Improved documentation for --adafactor Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Add mention of no_deprecation_warning Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename OptimizerOptions to OptimizerNames * Use choices for --optim * Move optimizer selection code to a function and add a unit test * Change optimizer names * Rename method Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename method Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Remove TODO comment Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename variable Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename variable Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename function * Rename variable * Parameterize the tests for supported optimizers * Refactor * Attempt to make tests pass on CircleCI * Add a test with apex * rework to add apex to parameterized; add actual train test * fix import when torch is not available * fix optim_test_params when torch is not available * fix optim_test_params when torch is not available * re-org * small re-org * fix test_fused_adam_no_apex * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Remove .value from OptimizerNames * Rename optimizer strings s|--adam_|--adamw_| * Also rename Enum options * small fix * Fix instantiation of OptimizerNames. Remove redundant test * Use ExplicitEnum instead of Enum * Add unit test with string optimizer * Change optimizer default to string value Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
committed by
GitHub
parent
762416ffa8
commit
7b83feb50a
@@ -23,10 +23,12 @@ import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from huggingface_hub import Repository, delete_repo, login
|
||||
from parameterized import parameterized
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -36,7 +38,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_apex_available
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
@@ -61,6 +63,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils.hp_naming import TrialShortNamer
|
||||
|
||||
|
||||
@@ -69,6 +72,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
import transformers.optimization
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
EarlyStoppingCallback,
|
||||
@@ -1711,3 +1715,98 @@ class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase):
|
||||
trainer.hyperparameter_search(
|
||||
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="sigopt", n_trials=4
|
||||
)
|
||||
|
||||
|
||||
optim_test_params = []
|
||||
if is_torch_available():
|
||||
default_adam_kwargs = {
|
||||
"betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2),
|
||||
"eps": TrainingArguments.adam_epsilon,
|
||||
"lr": TrainingArguments.learning_rate,
|
||||
}
|
||||
|
||||
optim_test_params = [
|
||||
(
|
||||
OptimizerNames.ADAMW_HF,
|
||||
transformers.optimization.AdamW,
|
||||
default_adam_kwargs,
|
||||
),
|
||||
(
|
||||
OptimizerNames.ADAMW_HF.value,
|
||||
transformers.optimization.AdamW,
|
||||
default_adam_kwargs,
|
||||
),
|
||||
(
|
||||
OptimizerNames.ADAMW_TORCH,
|
||||
torch.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
),
|
||||
(
|
||||
OptimizerNames.ADAFACTOR,
|
||||
transformers.optimization.Adafactor,
|
||||
{
|
||||
"scale_parameter": False,
|
||||
"relative_step": False,
|
||||
"lr": TrainingArguments.learning_rate,
|
||||
},
|
||||
),
|
||||
]
|
||||
if is_apex_available():
|
||||
import apex
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
OptimizerNames.ADAMW_APEX_FUSED,
|
||||
apex.optimizers.FusedAdam,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
def check_optim_and_kwargs(self, optim: OptimizerNames, mandatory_kwargs, expected_cls):
|
||||
args = TrainingArguments(optim=optim, output_dir="None")
|
||||
actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
self.assertEqual(expected_cls, actual_cls)
|
||||
self.assertIsNotNone(optim_kwargs)
|
||||
|
||||
for p, v in mandatory_kwargs.items():
|
||||
self.assertTrue(p in optim_kwargs)
|
||||
actual_v = optim_kwargs[p]
|
||||
self.assertTrue(actual_v == v, f"Failed check for {p}. Expected {v}, but got {actual_v}.")
|
||||
|
||||
@parameterized.expand(optim_test_params, skip_on_empty=True)
|
||||
def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):
|
||||
# exercises all the valid --optim options
|
||||
self.check_optim_and_kwargs(name, mandatory_kwargs, expected_cls)
|
||||
|
||||
trainer = get_regression_trainer(optim=name)
|
||||
trainer.train()
|
||||
|
||||
def test_fused_adam(self):
|
||||
# Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
|
||||
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
|
||||
# class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
|
||||
# the test to run without requiring an apex installation.
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"apex": mock,
|
||||
"apex.optimizers": mock.optimizers,
|
||||
"apex.optimizers.FusedAdam": mock.optimizers.FusedAdam,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
OptimizerNames.ADAMW_APEX_FUSED,
|
||||
default_adam_kwargs,
|
||||
mock.optimizers.FusedAdam,
|
||||
)
|
||||
|
||||
def test_fused_adam_no_apex(self):
|
||||
args = TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None")
|
||||
|
||||
# Pretend that apex does not exist, even if installed. By setting apex to None, importing
|
||||
# apex will fail even if apex is installed.
|
||||
with patch.dict("sys.modules", {"apex.optimizers": None}):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
Reference in New Issue
Block a user