Use Python 3.9 syntax in tests (#37343)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Union, get_args, get_origin
|
||||
from typing import List, Literal, Optional, Union, get_args, get_origin
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -93,21 +93,21 @@ class OptionalExample:
|
||||
foo: Optional[int] = None
|
||||
bar: Optional[float] = field(default=None, metadata={"help": "help message"})
|
||||
baz: Optional[str] = None
|
||||
ces: Optional[List[str]] = list_field(default=[])
|
||||
des: Optional[List[int]] = list_field(default=[])
|
||||
ces: Optional[list[str]] = list_field(default=[])
|
||||
des: Optional[list[int]] = list_field(default=[])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListExample:
|
||||
foo_int: List[int] = list_field(default=[])
|
||||
bar_int: List[int] = list_field(default=[1, 2, 3])
|
||||
foo_str: List[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
|
||||
foo_int: list[int] = list_field(default=[])
|
||||
bar_int: list[int] = list_field(default=[1, 2, 3])
|
||||
foo_str: list[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
foo_float: list[float] = list_field(default=[0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequiredExample:
|
||||
required_list: List[int] = field()
|
||||
required_list: list[int] = field()
|
||||
required_str: str = field()
|
||||
required_enum: BasicEnum = field()
|
||||
|
||||
@@ -435,11 +435,11 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
for field in fields.values():
|
||||
# First verify raw dict
|
||||
if field.type in (dict, Dict):
|
||||
if field.type in (dict, dict):
|
||||
raw_dict_fields.append(field)
|
||||
# Next check for `Union` or `Optional`
|
||||
elif get_origin(field.type) == Union:
|
||||
if any(arg in (dict, Dict) for arg in get_args(field.type)):
|
||||
if any(arg in (dict, dict) for arg in get_args(field.type)):
|
||||
optional_dict_fields.append(field)
|
||||
|
||||
# First check: anything in `raw_dict_fields` is very bad
|
||||
@@ -455,7 +455,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = get_args(field.type)
|
||||
# These should be returned as `dict`, `str`, ...
|
||||
# we only care about the first two
|
||||
self.assertIn(args[0], (Dict, dict))
|
||||
self.assertIn(args[0], (dict, dict))
|
||||
self.assertEqual(
|
||||
str(args[1]),
|
||||
"<class 'str'>",
|
||||
|
||||
Reference in New Issue
Block a user