Convert _VALID_DICT_FIELDS to class attribute for shared dict parsing in subclasses (#36736)

* make _VALID_DICT_FIELDS as a class attribute

* fix test case about TrainingArguments
This commit is contained in:
Qizhi Chen
2025-04-01 18:29:12 +08:00
committed by GitHub
parent ae34bd75fd
commit fac70ff3c0
2 changed files with 17 additions and 18 deletions

View File

@@ -29,7 +29,6 @@ import yaml
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import make_choice_type_function, string_to_bool
from transformers.testing_utils import require_torch
from transformers.training_args import _VALID_DICT_FIELDS
# Since Python 3.10, we can use the builtin `|` operator for Union types
@@ -412,7 +411,8 @@ class HfArgumentParserTest(unittest.TestCase):
args = BasicExample(**args_dict_for_yaml)
self.assertEqual(parsed_args, args)
def test_integration_training_args(self):
def test_z_integration_training_args(self):
# make sure that this test executes last in the test suite
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)
@@ -424,7 +424,7 @@ class HfArgumentParserTest(unittest.TestCase):
If this fails, a type annotation change is
needed on a new input
"""
base_list = _VALID_DICT_FIELDS.copy()
base_list = TrainingArguments._VALID_DICT_FIELDS.copy()
args = TrainingArguments
# First find any annotations that contain `dict`
@@ -468,7 +468,7 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertIn(
field.name,
base_list,
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `TrainingArguments._VALID_DICT_FIELDS`",
)
@require_torch