Allow --arg Value for booleans in HfArgumentParser (#9823)

* Allow --arg Value for booleans in HfArgumentParser

* Update last test

* Better error message
This commit is contained in:
Sylvain Gugger
2021-01-27 09:31:42 -05:00
committed by GitHub
parent 35d55b7b84
commit 893120facc
2 changed files with 45 additions and 9 deletions

View File

@@ -20,6 +20,7 @@ from enum import Enum
from typing import List, Optional
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool
def list_field(default=None, metadata=None):
@@ -44,6 +45,7 @@ class WithDefaultExample:
class WithDefaultBoolExample:
foo: bool = False
baz: bool = True
opt: Optional[bool] = None
class BasicEnum(Enum):
@@ -91,7 +93,7 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--bar", type=float, required=True)
expected.add_argument("--baz", type=str, required=True)
expected.add_argument("--flag", action="store_true")
expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?")
self.argparsersEqual(parser, expected)
def test_with_default(self):
@@ -106,15 +108,26 @@ class HfArgumentParserTest(unittest.TestCase):
parser = HfArgumentParser(WithDefaultBoolExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo", action="store_true")
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
expected.add_argument("--no_baz", action="store_false", dest="baz")
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
expected.add_argument("--opt", type=string_to_bool, default=None)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=False, baz=True))
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False))
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
def test_with_enum(self):
parser = HfArgumentParser(EnumExample)