Use Python 3.9 syntax in tests (#37343)

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-08 20:12:08 +08:00
committed by GitHub
parent 0fc683d1cd
commit 1e6b546ea6
666 changed files with 265 additions and 946 deletions

View File

@@ -92,7 +92,7 @@ class TestAddNewModelLike(unittest.TestCase):
f.write(content)
def check_result(self, file_name, expected_result):
with open(file_name, "r", encoding="utf-8") as f:
with open(file_name, encoding="utf-8") as f:
result = f.read()
self.assertEqual(result, expected_result)

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import unittest
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
@@ -119,7 +119,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
self.assertEqual(schema["function"], expected_schema)
def test_nested_list(self):
def fn(x: List[List[Union[str, int]]]):
def fn(x: list[list[Union[str, int]]]):
"""
Test function
@@ -173,7 +173,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
self.assertEqual(schema["function"], expected_schema)
def test_multiple_complex_arguments(self):
def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None):
def fn(x: list[Union[int, float]], y: Optional[Union[int, str]] = None):
"""
Test function
@@ -283,7 +283,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
self.assertEqual(schema["function"], expected_schema)
def test_tuple(self):
def fn(x: Tuple[int, str]):
def fn(x: tuple[int, str]):
"""
Test function
@@ -315,7 +315,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
self.assertEqual(schema["function"], expected_schema)
def test_single_element_tuple_fails(self):
def fn(x: Tuple[int]):
def fn(x: tuple[int]):
"""
Test function
@@ -333,7 +333,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
get_json_schema(fn)
def test_ellipsis_type_fails(self):
def fn(x: Tuple[int, ...]):
def fn(x: tuple[int, ...]):
"""
Test function
@@ -446,8 +446,8 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
def test_everything_all_at_once(self):
def fn(
x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello")
) -> Tuple[int, str]:
x: str, y: Optional[list[Union[str, int]]], z: tuple[Union[str, int], str] = (42, "hello")
) -> tuple[int, str]:
"""
Test function with multiple args, and docstring args that we have to strip out.

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +16,7 @@ import logging
import os
import unittest
from pathlib import Path
from typing import List, Union
from typing import Union
import transformers
from transformers.testing_utils import require_tf, require_torch, slow
@@ -35,8 +34,8 @@ class TestCodeExamples(unittest.TestCase):
self,
directory: Path,
identifier: Union[str, None] = None,
ignore_files: Union[List[str], None] = None,
n_identifier: Union[str, List[str], None] = None,
ignore_files: Union[list[str], None] = None,
n_identifier: Union[str, list[str], None] = None,
only_modules: bool = True,
):
"""
@@ -56,7 +55,7 @@ class TestCodeExamples(unittest.TestCase):
files = [file for file in files if identifier in file]
if n_identifier is not None:
if isinstance(n_identifier, List):
if isinstance(n_identifier, list):
for n_ in n_identifier:
files = [file for file in files if n_ not in file]
else:

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -22,7 +22,7 @@ from argparse import Namespace
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union, get_args, get_origin
from typing import List, Literal, Optional, Union, get_args, get_origin
import yaml
@@ -93,21 +93,21 @@ class OptionalExample:
foo: Optional[int] = None
bar: Optional[float] = field(default=None, metadata={"help": "help message"})
baz: Optional[str] = None
ces: Optional[List[str]] = list_field(default=[])
des: Optional[List[int]] = list_field(default=[])
ces: Optional[list[str]] = list_field(default=[])
des: Optional[list[int]] = list_field(default=[])
@dataclass
class ListExample:
foo_int: List[int] = list_field(default=[])
bar_int: List[int] = list_field(default=[1, 2, 3])
foo_str: List[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
foo_int: list[int] = list_field(default=[])
bar_int: list[int] = list_field(default=[1, 2, 3])
foo_str: list[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
foo_float: list[float] = list_field(default=[0.1, 0.2, 0.3])
@dataclass
class RequiredExample:
required_list: List[int] = field()
required_list: list[int] = field()
required_str: str = field()
required_enum: BasicEnum = field()
@@ -435,11 +435,11 @@ class HfArgumentParserTest(unittest.TestCase):
for field in fields.values():
# First verify raw dict
if field.type in (dict, Dict):
if field.type in (dict, dict):
raw_dict_fields.append(field)
# Next check for `Union` or `Optional`
elif get_origin(field.type) == Union:
if any(arg in (dict, Dict) for arg in get_args(field.type)):
if any(arg in (dict, dict) for arg in get_args(field.type)):
optional_dict_fields.append(field)
# First check: anything in `raw_dict_fields` is very bad
@@ -455,7 +455,7 @@ class HfArgumentParserTest(unittest.TestCase):
args = get_args(field.type)
# These should be returned as `dict`, `str`, ...
# we only care about the first two
self.assertIn(args[0], (Dict, dict))
self.assertIn(args[0], (dict, dict))
self.assertEqual(
str(args[1]),
"<class 'str'>",

View File

@@ -150,7 +150,7 @@ class GetFromCacheTests(unittest.TestCase):
_raise_exceptions_for_connection_errors=False,
)
# The name is the cached name which is not very easy to test, so instead we load the content.
config = json.loads(open(resolved_file, "r").read())
config = json.loads(open(resolved_file).read())
self.assertEqual(config["hidden_size"], 768)
def test_get_file_from_repo_local(self):

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2020 The Hugging Face Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -328,7 +327,7 @@ class TFModelUtilsTest(unittest.TestCase):
self.assertEqual(len(state_file), 1)
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
with open(index_file, encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())
@@ -367,7 +366,7 @@ class TFModelUtilsTest(unittest.TestCase):
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
with open(index_file, encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -236,7 +235,7 @@ if is_torch_available():
except OSError:
LOG.info("Loading model %s in offline mode failed as expected", TINY_IMAGE_CLASSIF)
else:
self.fail("Loading model {} in offline mode should fail".format(TINY_IMAGE_CLASSIF))
self.fail(f"Loading model {TINY_IMAGE_CLASSIF} in offline mode should fail")
# Download model -> Huggingface Hub not concerned by our offline mode
LOG.info("Downloading %s for offline tests", TINY_IMAGE_CLASSIF)
@@ -280,7 +279,7 @@ if is_torch_available():
except OSError:
LOG.info("Loading model %s in offline mode failed as expected", TINY_IMAGE_CLASSIF)
else:
self.fail("Loading model {} in offline mode should fail".format(TINY_IMAGE_CLASSIF))
self.fail(f"Loading model {TINY_IMAGE_CLASSIF} in offline mode should fail")
LOG.info("Downloading %s for offline tests", TINY_IMAGE_CLASSIF)
hub_api = HfApi()
@@ -574,7 +573,7 @@ class ModelUtilsTest(TestCasePlus):
def remove_torch_dtype(model_path):
file = f"{model_path}/config.json"
with open(file, "r", encoding="utf-8") as f:
with open(file, encoding="utf-8") as f:
s = json.load(f)
s.pop("torch_dtype")
with open(file, "w", encoding="utf-8") as f:
@@ -745,7 +744,7 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(len(state_dict), 1)
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
with open(index_file, encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())

View File

@@ -15,7 +15,6 @@
import subprocess
import sys
import unittest
from typing import Tuple
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
from transformers.testing_utils import TestCasePlus, require_torch
@@ -195,7 +194,7 @@ print("success")
stdout, _ = self._execute_with_env(load, run, HF_HUB_OFFLINE="1")
self.assertIn("True", stdout)
def _execute_with_env(self, *commands: Tuple[str, ...], should_fail: bool = False, **env) -> Tuple[str, str]:
def _execute_with_env(self, *commands: tuple[str, ...], should_fail: bool = False, **env) -> tuple[str, str]:
"""Execute Python code with a given environment and return the stdout/stderr as strings.
If `should_fail=True`, the command is expected to fail. Otherwise, it should succeed.

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");