Allow --arg Value for booleans in HfArgumentParser (#9823)

* Allow --arg Value for booleans in HfArgumentParser

* Update last test

* Better error message
This commit is contained in:
Sylvain Gugger
2021-01-27 09:31:42 -05:00
committed by GitHub
parent 35d55b7b84
commit 893120facc
2 changed files with 45 additions and 9 deletions

View File

@@ -15,7 +15,7 @@
import dataclasses
import json
import sys
from argparse import ArgumentParser
from argparse import ArgumentParser, ArgumentTypeError
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
@@ -25,6 +25,20 @@ DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def string_to_bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise ArgumentTypeError(
f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
)
class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
@@ -85,11 +99,20 @@ class HfArgumentParser(ArgumentParser):
if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default
elif field.type is bool or field.type is Optional[bool]:
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
kwargs["action"] = "store_false" if field.default is True else "store_true"
if field.default is True:
field_name = f"--no_{field.name}"
kwargs["dest"] = field.name
self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)
# Hack because type=bool in argparse does not behave as we want.
kwargs["type"] = string_to_bool
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
# Default value is True if we have no default when of type bool.
default = True if field.default is dataclasses.MISSING else field.default
# This is the value that will get picked if we don't include --field_name in any way
kwargs["default"] = default
# This tells argparse we accept 0 or 1 value after --field_name
kwargs["nargs"] = "?"
# This is the value that will get picked if we do --field_name (without value)
kwargs["const"] = True
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
kwargs["nargs"] = "+"
kwargs["type"] = field.type.__args__[0]