Convert _VALID_DICT_FIELDS to class attribute for shared dict parsing in subclasses (#36736)

* make _VALID_DICT_FIELDS as a class attribute

* fix test case about TrainingArguments
This commit is contained in:
Qizhi Chen
2025-04-01 18:29:12 +08:00
committed by GitHub
parent ae34bd75fd
commit fac70ff3c0
2 changed files with 17 additions and 18 deletions

View File

@@ -188,19 +188,6 @@ class OptimizerNames(ExplicitEnum):
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
# Sometimes users will pass in a `str` repr of a dict in the CLI
# We need to track what fields those can be. Each time a new arg
# has a dict type, it must be added to this list.
# Important: These should be typed with Optional[Union[dict,str,...]]
_VALID_DICT_FIELDS = [
"accelerator_config",
"fsdp_config",
"deepspeed",
"gradient_checkpointing_kwargs",
"lr_scheduler_kwargs",
]
def _convert_str_dict(passed_value: dict):
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
for key, value in passed_value.items():
@@ -814,6 +801,18 @@ class TrainingArguments:
https://github.com/huggingface/transformers/issues/34242
"""
# Sometimes users will pass in a `str` repr of a dict in the CLI
# We need to track what fields those can be. Each time a new arg
# has a dict type, it must be added to this list.
# Important: These should be typed with Optional[Union[dict,str,...]]
_VALID_DICT_FIELDS = [
"accelerator_config",
"fsdp_config",
"deepspeed",
"gradient_checkpointing_kwargs",
"lr_scheduler_kwargs",
]
framework = "pt"
output_dir: Optional[str] = field(
default=None,
@@ -1561,7 +1560,7 @@ class TrainingArguments:
)
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS:
for field in self._VALID_DICT_FIELDS:
passed_value = getattr(self, field)
# We only want to do this if the str starts with a bracket to indicate a `dict`
# else its likely a filename if supported

View File

@@ -29,7 +29,6 @@ import yaml
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import make_choice_type_function, string_to_bool
from transformers.testing_utils import require_torch
from transformers.training_args import _VALID_DICT_FIELDS
# Since Python 3.10, we can use the builtin `|` operator for Union types
@@ -412,7 +411,8 @@ class HfArgumentParserTest(unittest.TestCase):
args = BasicExample(**args_dict_for_yaml)
self.assertEqual(parsed_args, args)
def test_integration_training_args(self):
def test_z_integration_training_args(self):
# make sure that this test executes last in the test suite
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)
@@ -424,7 +424,7 @@ class HfArgumentParserTest(unittest.TestCase):
If this fails, a type annotation change is
needed on a new input
"""
base_list = _VALID_DICT_FIELDS.copy()
base_list = TrainingArguments._VALID_DICT_FIELDS.copy()
args = TrainingArguments
# First find any annotations that contain `dict`
@@ -468,7 +468,7 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertIn(
field.name,
base_list,
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `TrainingArguments._VALID_DICT_FIELDS`",
)
@require_torch