Defaults to adamw_torch_fused for PyTorch>=2.8 (#37358)
* Defaults to adamw_torch_fused for latest PyTorch

  Signed-off-by: cyy <cyyever@outlook.com>

* Fix test

  Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -28,6 +28,7 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
|
||||
is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
||||
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
|
||||
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
|
||||
|
||||
@@ -609,7 +609,7 @@ class TrainingArguments:
|
||||
- `"tpu_metrics_debug"`: print debug metrics on TPU
|
||||
|
||||
The options should be separated by whitespaces.
|
||||
optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
|
||||
optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"` (for torch>=2.8 `"adamw_torch_fused"`)):
|
||||
The optimizer to use, such as "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision",
|
||||
"adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
|
||||
for a full list of optimizers.
|
||||
@@ -1280,11 +1280,11 @@ class TrainingArguments:
|
||||
)
|
||||
|
||||
default_optim = "adamw_torch"
|
||||
# XXX: enable when pytorch==2.0.1 comes out - we want to give it time to get all the bugs sorted out
|
||||
# if is_torch_available():
|
||||
# default_optim = "adamw_torch_fused"
|
||||
# and update the doc above to:
|
||||
# optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch_fused"` (for torch<2.1.0 `"adamw_torch"`):
|
||||
if is_torch_available():
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_2_8
|
||||
|
||||
if is_torch_greater_or_equal_than_2_8:
|
||||
default_optim = "adamw_torch_fused"
|
||||
optim: Union[OptimizerNames, str] = field(
|
||||
default=default_optim,
|
||||
metadata={"help": "The optimizer to use."},
|
||||
|
||||
@@ -347,7 +347,8 @@ class GitModelTester:
|
||||
num_return_sequences=2,
|
||||
)
|
||||
|
||||
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||
self.parent.assertEqual(generated_ids.shape[0], self.batch_size * 2)
|
||||
self.parent.assertTrue(generated_ids.shape[1] < 20)
|
||||
|
||||
def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values):
|
||||
model = GitForCausalLM(config=config)
|
||||
|
||||
Reference in New Issue
Block a user