Allow for str versions of dicts based on typing (#30227)
* Bookmark, initial implementation. Need to test * Clean * Working fully, woop woop * I think working version now, testing * Fin! * rm cast, could keep None * Fix typing issue * rm typehint * Add test * Add tests and make more rigid
This commit is contained in:
@@ -173,6 +173,37 @@ class OptimizerNames(ExplicitEnum):
|
|||||||
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
|
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
|
||||||
|
|
||||||
|
|
||||||
|
# Sometimes users will pass in a `str` repr of a dict in the CLI
|
||||||
|
# We need to track what fields those can be. Each time a new arg
|
||||||
|
# has a dict type, it must be added to this list.
|
||||||
|
# Important: These should be typed with Optional[Union[dict,str,...]]
|
||||||
|
_VALID_DICT_FIELDS = [
|
||||||
|
"accelerator_config",
|
||||||
|
"fsdp_config",
|
||||||
|
"deepspeed",
|
||||||
|
"gradient_checkpointing_kwargs",
|
||||||
|
"lr_scheduler_kwargs",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_str_dict(passed_value: dict):
|
||||||
|
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
|
||||||
|
for key, value in passed_value.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
passed_value[key] = _convert_str_dict(value)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
# First check for bool and convert
|
||||||
|
if value.lower() in ("true", "false"):
|
||||||
|
passed_value[key] = value.lower() == "true"
|
||||||
|
# Check for digit
|
||||||
|
elif value.isdigit():
|
||||||
|
passed_value[key] = int(value)
|
||||||
|
elif value.replace(".", "", 1).isdigit():
|
||||||
|
passed_value[key] = float(value)
|
||||||
|
|
||||||
|
return passed_value
|
||||||
|
|
||||||
|
|
||||||
# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
|
# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArguments:
|
class TrainingArguments:
|
||||||
@@ -803,11 +834,11 @@ class TrainingArguments:
|
|||||||
default="linear",
|
default="linear",
|
||||||
metadata={"help": "The scheduler type to use."},
|
metadata={"help": "The scheduler type to use."},
|
||||||
)
|
)
|
||||||
lr_scheduler_kwargs: Optional[Dict] = field(
|
lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": (
|
||||||
"Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
|
"Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts."
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -1118,7 +1149,6 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Do not touch this type annotation or it will stop working in CLI
|
|
||||||
fsdp_config: Optional[Union[dict, str]] = field(
|
fsdp_config: Optional[Union[dict, str]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -1137,8 +1167,7 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Do not touch this type annotation or it will stop working in CLI
|
accelerator_config: Optional[Union[dict, str]] = field(
|
||||||
accelerator_config: Optional[str] = field(
|
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": (
|
||||||
@@ -1147,8 +1176,7 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Do not touch this type annotation or it will stop working in CLI
|
deepspeed: Optional[Union[dict, str]] = field(
|
||||||
deepspeed: Optional[str] = field(
|
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": (
|
||||||
@@ -1252,7 +1280,7 @@ class TrainingArguments:
|
|||||||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: Optional[dict] = field(
|
gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
|
"help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
|
||||||
@@ -1380,6 +1408,17 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# Parse in args that could be `dict` sent in from the CLI as a string
|
||||||
|
for field in _VALID_DICT_FIELDS:
|
||||||
|
passed_value = getattr(self, field)
|
||||||
|
# We only want to do this if the str starts with a bracket to indicate a `dict`
|
||||||
|
# else it's likely a filename if supported
|
||||||
|
if isinstance(passed_value, str) and passed_value.startswith("{"):
|
||||||
|
loaded_dict = json.loads(passed_value)
|
||||||
|
# Convert str values to types if applicable
|
||||||
|
loaded_dict = _convert_str_dict(loaded_dict)
|
||||||
|
setattr(self, field, loaded_dict)
|
||||||
|
|
||||||
# expand paths, if not os.makedirs("~/bar") will make directory
|
# expand paths, if not os.makedirs("~/bar") will make directory
|
||||||
# in the current directory instead of the actual home
|
# in the current directory instead of the actual home
|
||||||
# see https://github.com/huggingface/transformers/issues/10628
|
# see https://github.com/huggingface/transformers/issues/10628
|
||||||
|
|||||||
@@ -22,12 +22,14 @@ from argparse import Namespace
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Literal, Optional
|
from typing import Dict, List, Literal, Optional, Union, get_args, get_origin
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from transformers import HfArgumentParser, TrainingArguments
|
from transformers import HfArgumentParser, TrainingArguments
|
||||||
from transformers.hf_argparser import make_choice_type_function, string_to_bool
|
from transformers.hf_argparser import make_choice_type_function, string_to_bool
|
||||||
|
from transformers.testing_utils import require_torch
|
||||||
|
from transformers.training_args import _VALID_DICT_FIELDS
|
||||||
|
|
||||||
|
|
||||||
# Since Python 3.10, we can use the builtin `|` operator for Union types
|
# Since Python 3.10, we can use the builtin `|` operator for Union types
|
||||||
@@ -405,3 +407,68 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
def test_integration_training_args(self):
|
def test_integration_training_args(self):
|
||||||
parser = HfArgumentParser(TrainingArguments)
|
parser = HfArgumentParser(TrainingArguments)
|
||||||
self.assertIsNotNone(parser)
|
self.assertIsNotNone(parser)
|
||||||
|
|
||||||
|
def test_valid_dict_annotation(self):
    """
    Tests to make sure that `dict` based annotations
    are correctly made in the `TrainingArguments`.

    Three things are checked:
      1. no field is annotated as a bare `dict`/`Dict`;
      2. every `Union`/`Optional` annotation containing `dict` lists `dict`
         first and `str` second;
      3. every such field is registered in `_VALID_DICT_FIELDS`.

    If this fails, a type annotation change is
    needed on a new input
    """
    base_list = _VALID_DICT_FIELDS.copy()

    # First find any annotations that contain `dict`
    dataclass_fields = TrainingArguments.__dataclass_fields__

    raw_dict_fields = []
    optional_dict_fields = []

    # `dc_field` rather than `field`, to avoid shadowing `dataclasses.field`
    for dc_field in dataclass_fields.values():
        # First verify raw dict
        if dc_field.type in (dict, Dict):
            raw_dict_fields.append(dc_field)
        # Next check for `Union` or `Optional`
        elif get_origin(dc_field.type) == Union:
            if any(arg in (dict, Dict) for arg in get_args(dc_field.type)):
                optional_dict_fields.append(dc_field)

    # First check: anything in `raw_dict_fields` is very bad
    self.assertEqual(
        len(raw_dict_fields),
        0,
        "Found invalid raw `dict` types in the `TrainingArguments` typings. "
        "This leads to issues with the CLI. Please turn this into `typing.Optional[typing.Union[dict, str]]`",
    )

    # Next check raw annotations
    for dc_field in optional_dict_fields:
        type_args = get_args(dc_field.type)
        # These should be returned as `dict`, `str`, ...
        # we only care about the first two
        self.assertIn(type_args[0], (Dict, dict))
        self.assertIs(
            type_args[1],
            str,
            f"Expected field `{dc_field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, "
            "but `str` not found. Please fix this.",
        )

    # Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
    for dc_field in optional_dict_fields:
        self.assertIn(
            dc_field.name,
            base_list,
            f"Optional dict field `{dc_field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
        )
|
||||||
|
|
||||||
|
@require_torch
def test_valid_dict_input_parsing(self):
    """
    Passes `accelerator_config` as a JSON string and checks the values read
    back from `args.accelerator_config`: the string `"True"` compares equal
    to `True`, and the nested `num_steps` entry equals 2.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        args = TrainingArguments(
            output_dir=tmp_dir,
            accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
        )
        self.assertEqual(args.accelerator_config.split_batches, True)
        self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
|
||||||
|
|||||||
Reference in New Issue
Block a user