No more Tuple, List, Dict (#38797)
* No more Tuple, List, Dict * make fixup * More style fixes * Docstring fixes with regex replacement * Trigger tests * Redo fixes after rebase * Fix copies * [test all] * update * [test all] * update * [test all] * make style after rebase * Patch the hf_argparser test * Patch the hf_argparser test * style fixes * style fixes * style fixes * Fix docstrings in Cohere test * [test all] --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -309,7 +309,7 @@ class MarkupLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
# ValueError: Nodes must be of type `List[str]` (single pretokenized example), or `List[List[str]]`
|
||||
# ValueError: Nodes must be of type `list[str]` (single pretokenized example), or `list[list[str]]`
|
||||
# (batch of pretokenized examples).
|
||||
return True
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -52,7 +51,7 @@ class TimesFmModelTester:
|
||||
num_heads: int = 2,
|
||||
tolerance: float = 1e-6,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
quantiles: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
pad_val: float = 1123581321.0,
|
||||
use_positional_embedding: bool = True,
|
||||
initializer_factor: float = 0.0,
|
||||
|
||||
@@ -248,7 +248,7 @@ class UnivNetFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Test np.ndarray vs List[np.ndarray]
|
||||
# Test np.ndarray vs list[np.ndarray]
|
||||
encoded_sequences_1 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
|
||||
encoded_sequences_2 = feature_extractor([np_speech_inputs], return_tensors="np").input_features
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
|
||||
@@ -280,13 +280,13 @@ class PipelineTesterMixin:
|
||||
A model repository id on the Hub.
|
||||
model_architecture (`type`):
|
||||
A subclass of `PretrainedModel` or `PretrainedModel`.
|
||||
tokenizer_names (`List[str]`):
|
||||
tokenizer_names (`list[str]`):
|
||||
A list of names of a subclasses of `PreTrainedTokenizerFast` or `PreTrainedTokenizer`.
|
||||
image_processor_names (`List[str]`):
|
||||
image_processor_names (`list[str]`):
|
||||
A list of names of subclasses of `BaseImageProcessor`.
|
||||
feature_extractor_names (`List[str]`):
|
||||
feature_extractor_names (`list[str]`):
|
||||
A list of names of subclasses of `FeatureExtractionMixin`.
|
||||
processor_names (`List[str]`):
|
||||
processor_names (`list[str]`):
|
||||
A list of names of subclasses of `ProcessorMixin`.
|
||||
commit (`str`):
|
||||
The commit hash of the model repository on the Hub.
|
||||
|
||||
@@ -123,7 +123,7 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
|
||||
batch_size = self.feat_extract_tester.batch_size
|
||||
feature_size = self.feat_extract_tester.feature_size
|
||||
|
||||
# test padding for List[int] + numpy
|
||||
# test padding for list[int] + numpy
|
||||
input_1 = feat_extract.pad(processed_features, padding=False)
|
||||
input_1 = input_1[input_name]
|
||||
|
||||
@@ -157,7 +157,7 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
|
||||
if feature_size > 1:
|
||||
self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)
|
||||
|
||||
# test padding for `pad_to_multiple_of` for List[int] + numpy
|
||||
# test padding for `pad_to_multiple_of` for list[int] + numpy
|
||||
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)
|
||||
input_6 = input_6[input_name]
|
||||
|
||||
@@ -319,7 +319,7 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
|
||||
with self.assertRaises(ValueError):
|
||||
feat_extract.pad(processed_features, padding="max_length", truncation=True)[input_name]
|
||||
|
||||
# test truncation for `pad_to_multiple_of` for List[int] + numpy
|
||||
# test truncation for `pad_to_multiple_of` for list[int] + numpy
|
||||
pad_to_multiple_of = 12
|
||||
input_8 = feat_extract.pad(
|
||||
processed_features,
|
||||
|
||||
@@ -22,7 +22,7 @@ from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Union, get_args, get_origin
|
||||
from typing import Literal, Optional, Union, get_args, get_origin
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -121,7 +121,7 @@ class StringLiteralAnnotationExample:
|
||||
required_enum: "BasicEnum" = field()
|
||||
opt: "Optional[bool]" = None
|
||||
baz: "str" = field(default="toto", metadata={"help": "help message"})
|
||||
foo_str: "List[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
foo_str: "list[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
|
||||
|
||||
if is_python_no_less_than_3_10:
|
||||
@@ -435,11 +435,11 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
for field in fields.values():
|
||||
# First verify raw dict
|
||||
if field.type in (dict, dict):
|
||||
if field.type is dict:
|
||||
raw_dict_fields.append(field)
|
||||
# Next check for `Union` or `Optional`
|
||||
elif get_origin(field.type) == Union:
|
||||
if any(arg in (dict, dict) for arg in get_args(field.type)):
|
||||
if any(arg is dict for arg in get_args(field.type)):
|
||||
optional_dict_fields.append(field)
|
||||
|
||||
# First check: anything in `raw_dict_fields` is very bad
|
||||
@@ -455,12 +455,15 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = get_args(field.type)
|
||||
# These should be returned as `dict`, `str`, ...
|
||||
# we only care about the first two
|
||||
self.assertIn(args[0], (dict, dict))
|
||||
self.assertEqual(
|
||||
str(args[1]),
|
||||
"<class 'str'>",
|
||||
f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, "
|
||||
"but `str` not found. Please fix this.",
|
||||
self.assertIn(
|
||||
dict,
|
||||
args,
|
||||
f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, but `dict` not found. Please fix this.",
|
||||
)
|
||||
self.assertIn(
|
||||
str,
|
||||
args,
|
||||
f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, but `str` not found. Please fix this.",
|
||||
)
|
||||
|
||||
# Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
|
||||
|
||||
Reference in New Issue
Block a user