Allow --arg Value for booleans in HfArgumentParser (#9823)
* Allow --arg Value for booleans in HfArgumentParser * Update last test * Better error message
This commit is contained in:
@@ -20,6 +20,7 @@ from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
from transformers.hf_argparser import string_to_bool
|
||||
|
||||
|
||||
def list_field(default=None, metadata=None):
|
||||
@@ -44,6 +45,7 @@ class WithDefaultExample:
|
||||
class WithDefaultBoolExample:
|
||||
foo: bool = False
|
||||
baz: bool = True
|
||||
opt: Optional[bool] = None
|
||||
|
||||
|
||||
class BasicEnum(Enum):
|
||||
@@ -91,7 +93,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected.add_argument("--foo", type=int, required=True)
|
||||
expected.add_argument("--bar", type=float, required=True)
|
||||
expected.add_argument("--baz", type=str, required=True)
|
||||
expected.add_argument("--flag", action="store_true")
|
||||
expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?")
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_with_default(self):
|
||||
@@ -106,15 +108,26 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
parser = HfArgumentParser(WithDefaultBoolExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", action="store_true")
|
||||
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
|
||||
expected.add_argument("--no_baz", action="store_false", dest="baz")
|
||||
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
|
||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args, Namespace(foo=False, baz=True))
|
||||
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "--no_baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=False))
|
||||
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "--baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
|
||||
|
||||
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
|
||||
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
|
||||
|
||||
def test_with_enum(self):
|
||||
parser = HfArgumentParser(EnumExample)
|
||||
|
||||
Reference in New Issue
Block a user