Convert _VALID_DICT_FIELDS to class attribute for shared dict parsing in subclasses (#36736)
* make _VALID_DICT_FIELDS as a class attribute * fix test case about TrainingArguments
This commit is contained in:
@@ -188,19 +188,6 @@ class OptimizerNames(ExplicitEnum):
|
||||
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
|
||||
|
||||
|
||||
# Sometimes users will pass in a `str` repr of a dict in the CLI
|
||||
# We need to track what fields those can be. Each time a new arg
|
||||
# has a dict type, it must be added to this list.
|
||||
# Important: These should be typed with Optional[Union[dict,str,...]]
|
||||
_VALID_DICT_FIELDS = [
|
||||
"accelerator_config",
|
||||
"fsdp_config",
|
||||
"deepspeed",
|
||||
"gradient_checkpointing_kwargs",
|
||||
"lr_scheduler_kwargs",
|
||||
]
|
||||
|
||||
|
||||
def _convert_str_dict(passed_value: dict):
|
||||
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
|
||||
for key, value in passed_value.items():
|
||||
@@ -814,6 +801,18 @@ class TrainingArguments:
|
||||
https://github.com/huggingface/transformers/issues/34242
|
||||
"""
|
||||
|
||||
# Sometimes users will pass in a `str` repr of a dict in the CLI
|
||||
# We need to track what fields those can be. Each time a new arg
|
||||
# has a dict type, it must be added to this list.
|
||||
# Important: These should be typed with Optional[Union[dict,str,...]]
|
||||
_VALID_DICT_FIELDS = [
|
||||
"accelerator_config",
|
||||
"fsdp_config",
|
||||
"deepspeed",
|
||||
"gradient_checkpointing_kwargs",
|
||||
"lr_scheduler_kwargs",
|
||||
]
|
||||
|
||||
framework = "pt"
|
||||
output_dir: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -1561,7 +1560,7 @@ class TrainingArguments:
|
||||
)
|
||||
|
||||
# Parse in args that could be `dict` sent in from the CLI as a string
|
||||
for field in _VALID_DICT_FIELDS:
|
||||
for field in self._VALID_DICT_FIELDS:
|
||||
passed_value = getattr(self, field)
|
||||
# We only want to do this if the str starts with a bracket to indicate a `dict`
|
||||
# else its likely a filename if supported
|
||||
|
||||
@@ -29,7 +29,6 @@ import yaml
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
from transformers.hf_argparser import make_choice_type_function, string_to_bool
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.training_args import _VALID_DICT_FIELDS
|
||||
|
||||
|
||||
# Since Python 3.10, we can use the builtin `|` operator for Union types
|
||||
@@ -412,7 +411,8 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = BasicExample(**args_dict_for_yaml)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_integration_training_args(self):
|
||||
def test_z_integration_training_args(self):
|
||||
# make sure that this test executes last in the test suite
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
self.assertIsNotNone(parser)
|
||||
|
||||
@@ -424,7 +424,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
If this fails, a type annotation change is
|
||||
needed on a new input
|
||||
"""
|
||||
base_list = _VALID_DICT_FIELDS.copy()
|
||||
base_list = TrainingArguments._VALID_DICT_FIELDS.copy()
|
||||
args = TrainingArguments
|
||||
|
||||
# First find any annotations that contain `dict`
|
||||
@@ -468,7 +468,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
self.assertIn(
|
||||
field.name,
|
||||
base_list,
|
||||
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
|
||||
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `TrainingArguments._VALID_DICT_FIELDS`",
|
||||
)
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user