Allow --arg Value for booleans in HfArgumentParser (#9823)
* Allow --arg Value for booleans in HfArgumentParser * Update last test * Better error message
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
import dataclasses
|
||||
import json
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from argparse import ArgumentParser, ArgumentTypeError
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
|
||||
@@ -25,6 +25,20 @@ DataClass = NewType("DataClass", Any)
|
||||
DataClassType = NewType("DataClassType", Any)
|
||||
|
||||
|
||||
# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
||||
def string_to_bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise ArgumentTypeError(
|
||||
f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
|
||||
)
|
||||
|
||||
|
||||
class HfArgumentParser(ArgumentParser):
|
||||
"""
|
||||
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
|
||||
@@ -85,11 +99,20 @@ class HfArgumentParser(ArgumentParser):
|
||||
if field.default is not dataclasses.MISSING:
|
||||
kwargs["default"] = field.default
|
||||
elif field.type is bool or field.type is Optional[bool]:
|
||||
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
||||
kwargs["action"] = "store_false" if field.default is True else "store_true"
|
||||
if field.default is True:
|
||||
field_name = f"--no_{field.name}"
|
||||
kwargs["dest"] = field.name
|
||||
self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)
|
||||
|
||||
# Hack because type=bool in argparse does not behave as we want.
|
||||
kwargs["type"] = string_to_bool
|
||||
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
||||
# Default value is True if we have no default when of type bool.
|
||||
default = True if field.default is dataclasses.MISSING else field.default
|
||||
# This is the value that will get picked if we don't include --field_name in any way
|
||||
kwargs["default"] = default
|
||||
# This tells argparse we accept 0 or 1 value after --field_name
|
||||
kwargs["nargs"] = "?"
|
||||
# This is the value that will get picked if we do --field_name (without value)
|
||||
kwargs["const"] = True
|
||||
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
|
||||
kwargs["nargs"] = "+"
|
||||
kwargs["type"] = field.type.__args__[0]
|
||||
|
||||
Reference in New Issue
Block a user