From 3e93dd295b5343557a83bc07b0b2ea64c926f9b4 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 12 Oct 2023 18:22:09 +0200 Subject: [PATCH] Skip `TrainerIntegrationFSDP::test_basic_run_with_cpu_offload` if `torch < 2.1` (#26764) * fix * fix --------- Co-authored-by: ydshieh --- src/transformers/pytorch_utils.py | 1 + tests/fsdp/test_fsdp.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 73f4176d4b..d0bc55fe83 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -28,6 +28,7 @@ logger = logging.get_logger(__name__) parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) +is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1") is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11") diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index f9dd300626..69103dcd8c 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -14,6 +14,7 @@ import itertools import os +import unittest from functools import partial from parameterized import parameterized @@ -37,6 +38,11 @@ from transformers.trainer_utils import FSDPOption, set_seed from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available +if is_torch_available(): + from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1 +else: + is_torch_greater_or_equal_than_2_1 = False + # default torch.distributed port DEFAULT_MASTER_PORT = "10999" dtypes = ["fp16"] @@ -178,6 +184,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): @parameterized.expand(dtypes) @require_torch_multi_gpu @slow + @unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.") def test_basic_run_with_cpu_offload(self, dtype): launcher = get_launcher(distributed=True, use_accelerate=False) output_dir = self.get_auto_remove_tmp_dir()