diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index e4bd74720a..c3cc4579e5 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -28,6 +28,7 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm] logger = logging.get_logger(__name__) +is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True) is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True) is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True) is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 090ff41119..4ae3983c13 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -609,7 +609,7 @@ class TrainingArguments: - `"tpu_metrics_debug"`: print debug metrics on TPU The options should be separated by whitespaces. - optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`): + optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"` (for torch>=2.8 `"adamw_torch_fused"`)): The optimizer to use, such as "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision", "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py) for a full list of optimizers. @@ -1280,11 +1280,11 @@ class TrainingArguments: ) default_optim = "adamw_torch" - # XXX: enable when pytorch==2.0.1 comes out - we want to give it time to get all the bugs sorted out - # if is_torch_available(): - # default_optim = "adamw_torch_fused" - # and update the doc above to: - # optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch_fused"` (for torch<2.1.0 `"adamw_torch"`): + if is_torch_available(): + from .pytorch_utils import is_torch_greater_or_equal_than_2_8 + + if is_torch_greater_or_equal_than_2_8: + default_optim = "adamw_torch_fused" optim: Union[OptimizerNames, str] = field( default=default_optim, metadata={"help": "The optimizer to use."}, diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index 38aa2b4e87..36683777b4 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -347,7 +347,8 @@ class GitModelTester: num_return_sequences=2, ) - self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20)) + self.parent.assertEqual(generated_ids.shape[0], self.batch_size * 2) + self.parent.assertTrue(generated_ids.shape[1] < 20) def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values): model = GitForCausalLM(config=config)