byebye torch 2.0 (#37277)
* bump Torch 2.1 with broken compatibility `torch.compile`
* dep table
* remove usage of is_torch_greater_or_equal_than_2_1
* remove usage of is_torch_greater_or_equal_than_2_1
* remove if is_torch_greater_or_equal("2.1.0")
* remove torch >= "2.1.0"
* deal with 2.0.0
* PyTorch 2.0+ --> PyTorch 2.1+
* ruff 1
* difficult ruff
* address comment
* address comment
---------
Co-authored-by: Jirka B <j.borovec+github@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -485,20 +485,15 @@ str_to_torch_dtype = {
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
"F8_E5M2": torch.float8_e5m2,
|
||||
}
|
||||
|
||||
if is_torch_greater_or_equal("2.1.0"):
|
||||
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
|
||||
|
||||
if is_torch_greater_or_equal("2.3.0"):
|
||||
str_to_torch_dtype["U16"] = torch.uint16
|
||||
str_to_torch_dtype["U32"] = torch.uint32
|
||||
str_to_torch_dtype["U64"] = torch.uint64
|
||||
|
||||
if is_torch_greater_or_equal("2.1.0"):
|
||||
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
|
||||
str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2
|
||||
|
||||
|
||||
def load_state_dict(
|
||||
checkpoint_file: Union[str, os.PathLike],
|
||||
@@ -546,12 +541,7 @@ def load_state_dict(
|
||||
map_location = "cpu"
|
||||
extra_args = {}
|
||||
# mmap can only be used with files serialized with zipfile-based format.
|
||||
if (
|
||||
isinstance(checkpoint_file, str)
|
||||
and map_location != "meta"
|
||||
and version.parse(torch.__version__) >= version.parse("2.1.0")
|
||||
and is_zipfile(checkpoint_file)
|
||||
):
|
||||
if isinstance(checkpoint_file, str) and map_location != "meta" and is_zipfile(checkpoint_file):
|
||||
extra_args = {"mmap": True}
|
||||
return torch.load(
|
||||
checkpoint_file,
|
||||
|
||||
Reference in New Issue
Block a user