Use Python 3.9 syntax in tests (#37343)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -92,7 +92,7 @@ class TestAddNewModelLike(unittest.TestCase):
|
||||
f.write(content)
|
||||
|
||||
def check_result(self, file_name, expected_result):
|
||||
with open(file_name, "r", encoding="utf-8") as f:
|
||||
with open(file_name, encoding="utf-8") as f:
|
||||
result = f.read()
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
|
||||
|
||||
@@ -119,7 +119,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_nested_list(self):
|
||||
def fn(x: List[List[Union[str, int]]]):
|
||||
def fn(x: list[list[Union[str, int]]]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
@@ -173,7 +173,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiple_complex_arguments(self):
|
||||
def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None):
|
||||
def fn(x: list[Union[int, float]], y: Optional[Union[int, str]] = None):
|
||||
"""
|
||||
Test function
|
||||
|
||||
@@ -283,7 +283,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_tuple(self):
|
||||
def fn(x: Tuple[int, str]):
|
||||
def fn(x: tuple[int, str]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
@@ -315,7 +315,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_single_element_tuple_fails(self):
|
||||
def fn(x: Tuple[int]):
|
||||
def fn(x: tuple[int]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
@@ -333,7 +333,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_ellipsis_type_fails(self):
|
||||
def fn(x: Tuple[int, ...]):
|
||||
def fn(x: tuple[int, ...]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
@@ -446,8 +446,8 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
|
||||
def test_everything_all_at_once(self):
|
||||
def fn(
|
||||
x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello")
|
||||
) -> Tuple[int, str]:
|
||||
x: str, y: Optional[list[Union[str, int]]], z: tuple[Union[str, int], str] = (42, "hello")
|
||||
) -> tuple[int, str]:
|
||||
"""
|
||||
Test function with multiple args, and docstring args that we have to strip out.
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -17,7 +16,7 @@ import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from typing import Union
|
||||
|
||||
import transformers
|
||||
from transformers.testing_utils import require_tf, require_torch, slow
|
||||
@@ -35,8 +34,8 @@ class TestCodeExamples(unittest.TestCase):
|
||||
self,
|
||||
directory: Path,
|
||||
identifier: Union[str, None] = None,
|
||||
ignore_files: Union[List[str], None] = None,
|
||||
n_identifier: Union[str, List[str], None] = None,
|
||||
ignore_files: Union[list[str], None] = None,
|
||||
n_identifier: Union[str, list[str], None] = None,
|
||||
only_modules: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -56,7 +55,7 @@ class TestCodeExamples(unittest.TestCase):
|
||||
files = [file for file in files if identifier in file]
|
||||
|
||||
if n_identifier is not None:
|
||||
if isinstance(n_identifier, List):
|
||||
if isinstance(n_identifier, list):
|
||||
for n_ in n_identifier:
|
||||
files = [file for file in files if n_ not in file]
|
||||
else:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -22,7 +22,7 @@ from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Union, get_args, get_origin
|
||||
from typing import List, Literal, Optional, Union, get_args, get_origin
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -93,21 +93,21 @@ class OptionalExample:
|
||||
foo: Optional[int] = None
|
||||
bar: Optional[float] = field(default=None, metadata={"help": "help message"})
|
||||
baz: Optional[str] = None
|
||||
ces: Optional[List[str]] = list_field(default=[])
|
||||
des: Optional[List[int]] = list_field(default=[])
|
||||
ces: Optional[list[str]] = list_field(default=[])
|
||||
des: Optional[list[int]] = list_field(default=[])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListExample:
|
||||
foo_int: List[int] = list_field(default=[])
|
||||
bar_int: List[int] = list_field(default=[1, 2, 3])
|
||||
foo_str: List[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
|
||||
foo_int: list[int] = list_field(default=[])
|
||||
bar_int: list[int] = list_field(default=[1, 2, 3])
|
||||
foo_str: list[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
foo_float: list[float] = list_field(default=[0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequiredExample:
|
||||
required_list: List[int] = field()
|
||||
required_list: list[int] = field()
|
||||
required_str: str = field()
|
||||
required_enum: BasicEnum = field()
|
||||
|
||||
@@ -435,11 +435,11 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
for field in fields.values():
|
||||
# First verify raw dict
|
||||
if field.type in (dict, Dict):
|
||||
if field.type in (dict, dict):
|
||||
raw_dict_fields.append(field)
|
||||
# Next check for `Union` or `Optional`
|
||||
elif get_origin(field.type) == Union:
|
||||
if any(arg in (dict, Dict) for arg in get_args(field.type)):
|
||||
if any(arg in (dict, dict) for arg in get_args(field.type)):
|
||||
optional_dict_fields.append(field)
|
||||
|
||||
# First check: anything in `raw_dict_fields` is very bad
|
||||
@@ -455,7 +455,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = get_args(field.type)
|
||||
# These should be returned as `dict`, `str`, ...
|
||||
# we only care about the first two
|
||||
self.assertIn(args[0], (Dict, dict))
|
||||
self.assertIn(args[0], (dict, dict))
|
||||
self.assertEqual(
|
||||
str(args[1]),
|
||||
"<class 'str'>",
|
||||
|
||||
@@ -150,7 +150,7 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
# The name is the cached name which is not very easy to test, so instead we load the content.
|
||||
config = json.loads(open(resolved_file, "r").read())
|
||||
config = json.loads(open(resolved_file).read())
|
||||
self.assertEqual(config["hidden_size"], 768)
|
||||
|
||||
def test_get_file_from_repo_local(self):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Hugging Face Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -328,7 +327,7 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
self.assertEqual(len(state_file), 1)
|
||||
|
||||
# Check the index and the shard files found match
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
with open(index_file, encoding="utf-8") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
@@ -367,7 +366,7 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
|
||||
|
||||
# Check the index and the shard files found match
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
with open(index_file, encoding="utf-8") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -236,7 +235,7 @@ if is_torch_available():
|
||||
except OSError:
|
||||
LOG.info("Loading model %s in offline mode failed as expected", TINY_IMAGE_CLASSIF)
|
||||
else:
|
||||
self.fail("Loading model {} in offline mode should fail".format(TINY_IMAGE_CLASSIF))
|
||||
self.fail(f"Loading model {TINY_IMAGE_CLASSIF} in offline mode should fail")
|
||||
|
||||
# Download model -> Huggingface Hub not concerned by our offline mode
|
||||
LOG.info("Downloading %s for offline tests", TINY_IMAGE_CLASSIF)
|
||||
@@ -280,7 +279,7 @@ if is_torch_available():
|
||||
except OSError:
|
||||
LOG.info("Loading model %s in offline mode failed as expected", TINY_IMAGE_CLASSIF)
|
||||
else:
|
||||
self.fail("Loading model {} in offline mode should fail".format(TINY_IMAGE_CLASSIF))
|
||||
self.fail(f"Loading model {TINY_IMAGE_CLASSIF} in offline mode should fail")
|
||||
|
||||
LOG.info("Downloading %s for offline tests", TINY_IMAGE_CLASSIF)
|
||||
hub_api = HfApi()
|
||||
@@ -574,7 +573,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
def remove_torch_dtype(model_path):
|
||||
file = f"{model_path}/config.json"
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
with open(file, encoding="utf-8") as f:
|
||||
s = json.load(f)
|
||||
s.pop("torch_dtype")
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
@@ -745,7 +744,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertEqual(len(state_dict), 1)
|
||||
|
||||
# Check the index and the shard files found match
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
with open(index_file, encoding="utf-8") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
|
||||
from transformers.testing_utils import TestCasePlus, require_torch
|
||||
@@ -195,7 +194,7 @@ print("success")
|
||||
stdout, _ = self._execute_with_env(load, run, HF_HUB_OFFLINE="1")
|
||||
self.assertIn("True", stdout)
|
||||
|
||||
def _execute_with_env(self, *commands: Tuple[str, ...], should_fail: bool = False, **env) -> Tuple[str, str]:
|
||||
def _execute_with_env(self, *commands: tuple[str, ...], should_fail: bool = False, **env) -> tuple[str, str]:
|
||||
"""Execute Python code with a given environment and return the stdout/stderr as strings.
|
||||
|
||||
If `should_fail=True`, the command is expected to fail. Otherwise, it should succeed.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
Reference in New Issue
Block a user