Allow for str versions of dicts based on typing (#30227)
* Bookmark, initial impelemtation. Need to test * Clean * Working fully, woop woop * I think working version now, testing * Fin! * rm cast, could keep None * Fix typing issue * rm typehint * Add test * Add tests and make more rigid
This commit is contained in:
@@ -22,12 +22,14 @@ from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional
|
||||
from typing import Dict, List, Literal, Optional, Union, get_args, get_origin
|
||||
|
||||
import yaml
|
||||
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
from transformers.hf_argparser import make_choice_type_function, string_to_bool
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.training_args import _VALID_DICT_FIELDS
|
||||
|
||||
|
||||
# Since Python 3.10, we can use the builtin `|` operator for Union types
|
||||
@@ -405,3 +407,68 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
def test_integration_training_args(self):
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
self.assertIsNotNone(parser)
|
||||
|
||||
def test_valid_dict_annotation(self):
|
||||
"""
|
||||
Tests to make sure that `dict` based annotations
|
||||
are correctly made in the `TrainingArguments`.
|
||||
|
||||
If this fails, a type annotation change is
|
||||
needed on a new input
|
||||
"""
|
||||
base_list = _VALID_DICT_FIELDS.copy()
|
||||
args = TrainingArguments
|
||||
|
||||
# First find any annotations that contain `dict`
|
||||
fields = args.__dataclass_fields__
|
||||
|
||||
raw_dict_fields = []
|
||||
optional_dict_fields = []
|
||||
|
||||
for field in fields.values():
|
||||
# First verify raw dict
|
||||
if field.type in (dict, Dict):
|
||||
raw_dict_fields.append(field)
|
||||
# Next check for `Union` or `Optional`
|
||||
elif get_origin(field.type) == Union:
|
||||
if any(arg in (dict, Dict) for arg in get_args(field.type)):
|
||||
optional_dict_fields.append(field)
|
||||
|
||||
# First check: anything in `raw_dict_fields` is very bad
|
||||
self.assertEqual(
|
||||
len(raw_dict_fields),
|
||||
0,
|
||||
"Found invalid raw `dict` types in the `TrainingArgument` typings. "
|
||||
"This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`",
|
||||
)
|
||||
|
||||
# Next check raw annotations
|
||||
for field in optional_dict_fields:
|
||||
args = get_args(field.type)
|
||||
# These should be returned as `dict`, `str`, ...
|
||||
# we only care about the first two
|
||||
self.assertIn(args[0], (Dict, dict))
|
||||
self.assertEqual(
|
||||
str(args[1]),
|
||||
"<class 'str'>",
|
||||
f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, "
|
||||
"but `str` not found. Please fix this.",
|
||||
)
|
||||
|
||||
# Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
|
||||
for field in optional_dict_fields:
|
||||
self.assertIn(
|
||||
field.name,
|
||||
base_list,
|
||||
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_valid_dict_input_parsing(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
|
||||
)
|
||||
self.assertEqual(args.accelerator_config.split_batches, True)
|
||||
self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
|
||||
|
||||
Reference in New Issue
Block a user