byebye torch 2.0 (#37277)

* bump Torch 2.1 with broken compatibility `torch.compile`

* dep table

* remove usage of is_torch_greater_or_equal_than_2_1

* remove usage of is_torch_greater_or_equal_than_2_1

* remove if is_torch_greater_or_equal("2.1.0")

* remove torch >= "2.1.0"

* deal with 2.0.0

* PyTorch 2.0+ --> PyTorch 2.1+

* ruff 1

* difficult ruff

* address comment

* address comment

---------

Co-authored-by: Jirka B <j.borovec+github@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-04-07 15:19:47 +02:00
committed by GitHub
parent 99f9f1042f
commit e7ad077012
28 changed files with 38 additions and 113 deletions

View File

@@ -485,20 +485,15 @@ str_to_torch_dtype = {
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
}
if is_torch_greater_or_equal("2.1.0"):
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
if is_torch_greater_or_equal("2.3.0"):
str_to_torch_dtype["U16"] = torch.uint16
str_to_torch_dtype["U32"] = torch.uint32
str_to_torch_dtype["U64"] = torch.uint64
if is_torch_greater_or_equal("2.1.0"):
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
@@ -546,12 +541,7 @@ def load_state_dict(
map_location = "cpu"
extra_args = {}
# mmap can only be used with files serialized with zipfile-based format.
if (
isinstance(checkpoint_file, str)
and map_location != "meta"
and version.parse(torch.__version__) >= version.parse("2.1.0")
and is_zipfile(checkpoint_file)
):
if isinstance(checkpoint_file, str) and map_location != "meta" and is_zipfile(checkpoint_file):
extra_args = {"mmap": True}
return torch.load(
checkpoint_file,