Skip TrainerIntegrationFSDP::test_basic_run_with_cpu_offload if torch < 2.1 (#26764)
* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
||||||
|
|
||||||
|
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
||||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
||||||
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
|
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
|
||||||
is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11")
|
is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
|
import unittest
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@@ -37,6 +38,11 @@ from transformers.trainer_utils import FSDPOption, set_seed
|
|||||||
from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
|
from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
|
||||||
|
else:
|
||||||
|
is_torch_greater_or_equal_than_2_1 = False
|
||||||
|
|
||||||
# default torch.distributed port
|
# default torch.distributed port
|
||||||
DEFAULT_MASTER_PORT = "10999"
|
DEFAULT_MASTER_PORT = "10999"
|
||||||
dtypes = ["fp16"]
|
dtypes = ["fp16"]
|
||||||
@@ -178,6 +184,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
@parameterized.expand(dtypes)
|
@parameterized.expand(dtypes)
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@slow
|
@slow
|
||||||
|
@unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
|
||||||
def test_basic_run_with_cpu_offload(self, dtype):
|
def test_basic_run_with_cpu_offload(self, dtype):
|
||||||
launcher = get_launcher(distributed=True, use_accelerate=False)
|
launcher = get_launcher(distributed=True, use_accelerate=False)
|
||||||
output_dir = self.get_auto_remove_tmp_dir()
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|||||||
Reference in New Issue
Block a user