Convert _VALID_DICT_FIELDS to class attribute for shared dict parsing in subclasses (#36736)
* make _VALID_DICT_FIELDS as a class attribute * fix test case about TrainingArguments
This commit is contained in:
@@ -29,7 +29,6 @@ import yaml
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
from transformers.hf_argparser import make_choice_type_function, string_to_bool
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.training_args import _VALID_DICT_FIELDS
|
||||
|
||||
|
||||
# Since Python 3.10, we can use the builtin `|` operator for Union types
|
||||
@@ -412,7 +411,8 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = BasicExample(**args_dict_for_yaml)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_integration_training_args(self):
|
||||
def test_z_integration_training_args(self):
|
||||
# make sure that this test executes last in the test suite
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
self.assertIsNotNone(parser)
|
||||
|
||||
@@ -424,7 +424,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
If this fails, a type annotation change is
|
||||
needed on a new input
|
||||
"""
|
||||
base_list = _VALID_DICT_FIELDS.copy()
|
||||
base_list = TrainingArguments._VALID_DICT_FIELDS.copy()
|
||||
args = TrainingArguments
|
||||
|
||||
# First find any annotations that contain `dict`
|
||||
@@ -468,7 +468,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
self.assertIn(
|
||||
field.name,
|
||||
base_list,
|
||||
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
|
||||
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `TrainingArguments._VALID_DICT_FIELDS`",
|
||||
)
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user