No more Tuple, List, Dict (#38797)
* No more Tuple, List, Dict * make fixup * More style fixes * Docstring fixes with regex replacement * Trigger tests * Redo fixes after rebase * Fix copies * [test all] * update * [test all] * update * [test all] * make style after rebase * Patch the hf_argparser test * Patch the hf_argparser test * style fixes * style fixes * style fixes * Fix docstrings in Cohere test * [test all] --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Union, get_args, get_origin
|
||||
from typing import Literal, Optional, Union, get_args, get_origin
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -121,7 +121,7 @@ class StringLiteralAnnotationExample:
|
||||
required_enum: "BasicEnum" = field()
|
||||
opt: "Optional[bool]" = None
|
||||
baz: "str" = field(default="toto", metadata={"help": "help message"})
|
||||
foo_str: "List[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
foo_str: "list[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
|
||||
|
||||
if is_python_no_less_than_3_10:
|
||||
@@ -435,11 +435,11 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
for field in fields.values():
|
||||
# First verify raw dict
|
||||
if field.type in (dict, dict):
|
||||
if field.type is dict:
|
||||
raw_dict_fields.append(field)
|
||||
# Next check for `Union` or `Optional`
|
||||
elif get_origin(field.type) == Union:
|
||||
if any(arg in (dict, dict) for arg in get_args(field.type)):
|
||||
if any(arg is dict for arg in get_args(field.type)):
|
||||
optional_dict_fields.append(field)
|
||||
|
||||
# First check: anything in `raw_dict_fields` is very bad
|
||||
@@ -455,12 +455,15 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = get_args(field.type)
|
||||
# These should be returned as `dict`, `str`, ...
|
||||
# we only care about the first two
|
||||
self.assertIn(args[0], (dict, dict))
|
||||
self.assertEqual(
|
||||
str(args[1]),
|
||||
"<class 'str'>",
|
||||
f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, "
|
||||
"but `str` not found. Please fix this.",
|
||||
self.assertIn(
|
||||
dict,
|
||||
args,
|
||||
f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, but `dict` not found. Please fix this.",
|
||||
)
|
||||
self.assertIn(
|
||||
str,
|
||||
args,
|
||||
f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, but `str` not found. Please fix this.",
|
||||
)
|
||||
|
||||
# Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
|
||||
|
||||
Reference in New Issue
Block a user