byebye torch 2.0 (#37277)

* bump Torch 2.1 with broken compatibility `torch.compile`

* dep table

* remove usage of is_torch_greater_or_equal_than_2_1

* remove usage of is_torch_greater_or_equal_than_2_1

* remove if is_torch_greater_or_equal("2.1.0")

* remove torch >= "2.1.0"

* deal with 2.0.0

* PyTorch 2.0+ --> PyTorch 2.1+

* ruff 1

* difficult ruff

* address comment

* address comment

---------

Co-authored-by: Jirka B <j.borovec+github@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-04-07 15:19:47 +02:00
committed by GitHub
parent 99f9f1042f
commit e7ad077012
28 changed files with 38 additions and 113 deletions

View File

@@ -47,10 +47,7 @@ from transformers.utils import (
if is_torch_available():
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
from transformers.trainer import FSDP_MODEL_NAME
else:
is_torch_greater_or_equal_than_2_1 = False
# default torch.distributed port
DEFAULT_MASTER_PORT = "10999"
@@ -260,7 +257,6 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
@require_torch_multi_accelerator
@run_first
@slow
@unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
def test_basic_run_with_cpu_offload(self, dtype):
launcher = get_launcher(distributed=True, use_accelerate=False)
output_dir = self.get_auto_remove_tmp_dir()