* fix --lr_scheduler_type choices * rewrite to fix for all enum-based cl args * cleanup * adjust test * style * Proposal that should work * Remove needless code * Fix test Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
219 lines
10 KiB
Python
219 lines
10 KiB
Python
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import dataclasses
|
|
import json
|
|
import sys
|
|
from argparse import ArgumentParser, ArgumentTypeError
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
|
|
|
|
|
|
# Annotation-only alias (a NewType over Any) used in this file for *instances* of the
# dataclasses the parser fills in, e.g. in the `Tuple[DataClass, ...]` return annotations.
DataClass = NewType("DataClass", Any)
|
|
# Annotation-only alias (a NewType over Any) used in this file for dataclass *types*
# (the classes themselves), i.e. what is handed to the parser to generate arguments from.
DataClassType = NewType("DataClassType", Any)
|
|
|
|
|
|
# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
|
def string_to_bool(v):
    """
    Convert a command-line token into a boolean.

    Args:
        v: Either an actual `bool` (returned unchanged) or a string spelling a truth value.

    Returns:
        `True` for yes/true/t/y/1 and `False` for no/false/f/n/0, compared case-insensitively.

    Raises:
        ArgumentTypeError: If the string is not one of the accepted spellings.
    """
    # Values that are already booleans need no parsing.
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False

    raise ArgumentTypeError(
        f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )
|
|
|
|
|
|
class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace.
    """

    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs:
                (Optional) Passed to `argparse.ArgumentParser()` in the regular way.
        """
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        # Materialize the iterable: it is walked here and again on every parse call, so a one-shot
        # iterable (e.g. a generator) would otherwise be exhausted after registration.
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """
        Register one `--<field name>` option per init-able field of `dtype`.

        The field's `metadata` mapping is forwarded as keyword arguments to `add_argument`; the type, default,
        `nargs`, `choices` and `required` settings are derived from the field's annotation and default.

        Args:
            dtype: The dataclass type whose fields are turned into command-line options.

        Raises:
            ImportError: If a field annotation is a string (postponed evaluation of annotations).
        """
        for field in dataclasses.fields(dtype):
            if not field.init:
                continue
            field_name = f"--{field.name}"
            kwargs = field.metadata.copy()
            # field.metadata is not used at all by Data Classes,
            # it is provided as a third-party extension mechanism.
            if isinstance(field.type, str):
                raise ImportError(
                    "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563), "
                    "which can be opted in from Python 3.7 with `from __future__ import annotations`. "
                    "We will add compatibility when Python 3.9 is released."
                )

            # Unwrap Optional[X] / Union[X, None] for primitive X and List[X], so that the branches
            # below see the underlying type. Note that this rewrites `field.type` in place.
            typestring = str(field.type)
            for prim_type in (int, float, str):
                for collection in (List,):
                    if (
                        typestring == f"typing.Union[{collection[prim_type]}, NoneType]"
                        or typestring == f"typing.Optional[{collection[prim_type]}]"
                    ):
                        field.type = collection[prim_type]
                if (
                    typestring == f"typing.Union[{prim_type.__name__}, NoneType]"
                    or typestring == f"typing.Optional[{prim_type.__name__}]"
                ):
                    field.type = prim_type

            if isinstance(field.type, type) and issubclass(field.type, Enum):
                # Enum fields are exposed through their member values; the command-line token is
                # converted with the type of the first value.
                kwargs["choices"] = [x.value for x in field.type]
                kwargs["type"] = type(kwargs["choices"][0])
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
                else:
                    # Without a default the option must be given, otherwise the dataclass would
                    # silently receive None.
                    kwargs["required"] = True
            elif field.type is bool or field.type is Optional[bool]:
                if field.default is True:
                    # Offer a `--no_<name>` switch to turn off a flag that defaults to True.
                    self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)

                # Hack because type=bool in argparse does not behave as we want.
                kwargs["type"] = string_to_bool
                if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
                    # Default value is True if we have no default when of type bool.
                    default = True if field.default is dataclasses.MISSING else field.default
                    # This is the value that will get picked if we don't include --field_name in any way
                    kwargs["default"] = default
                    # This tells argparse we accept 0 or 1 value after --field_name
                    kwargs["nargs"] = "?"
                    # This is the value that will get picked if we do --field_name (without value)
                    kwargs["const"] = True
            elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
                kwargs["nargs"] = "+"
                kwargs["type"] = field.type.__args__[0]
                assert all(
                    x == kwargs["type"] for x in field.type.__args__
                ), "{} cannot be a List of mixed types".format(field.name)
                if field.default_factory is not dataclasses.MISSING:
                    kwargs["default"] = field.default_factory()
                elif field.default is dataclasses.MISSING:
                    # Neither a default nor a default factory: the option is mandatory.
                    kwargs["required"] = True
            else:
                kwargs["type"] = field.type
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
                elif field.default_factory is not dataclasses.MISSING:
                    kwargs["default"] = field.default_factory()
                else:
                    kwargs["required"] = True
            self.add_argument(field_name, **kwargs)

    def parse_args_into_dataclasses(
        self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name as the entry point script for this
                process, and will prepend its potential content to the command line args.
            args_filename:
                If not None, will use this file instead of the ".args" file specified in the previous argument.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)

        Raises:
            ValueError: If `return_remaining_strings` is false and some argument strings were not consumed.
        """
        if args_filename or (look_for_args_file and sys.argv):
            if args_filename:
                args_file = Path(args_filename)
            else:
                args_file = Path(sys.argv[0]).with_suffix(".args")

            if args_file.exists():
                fargs = args_file.read_text().split()
                # The file's arguments go first: when an option appears more than once argparse keeps
                # the last occurrence, so explicitly passed / command-line values override the file.
                explicit_args = args if args is not None else sys.argv[1:]
                args = fargs + explicit_args
        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            # Strip the consumed entries so that only non-dataclass arguments remain in the namespace.
            for k in keys:
                delattr(namespace, k)
            obj = dtype(**inputs)
            outputs.append(obj)
        if len(namespace.__dict__) > 0:
            # additional namespace.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        else:
            if remaining_args:
                raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

            return (*outputs,)

    def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file: Path of the JSON file; its top-level object maps field names to values.

        Returns:
            Tuple of dataclass instances, in the same order as the types were passed to the initializer.
        """
        data = json.loads(Path(json_file).read_text())
        # Same key filtering and instantiation as for an in-memory dict.
        return self.parse_dict(data)

    def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
        types.

        Args:
            args: Mapping from field names to values. Keys matching no init-able field of a given dataclass are
                ignored for that dataclass.

        Returns:
            Tuple of dataclass instances, in the same order as the types were passed to the initializer.
        """
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in args.items() if k in keys}
            obj = dtype(**inputs)
            outputs.append(obj)
        return (*outputs,)
|