[traner] fix --lr_scheduler_type choices (#9800)
* fix --lr_scheduler_type choices * rewrite to fix for all enum-based cl args * cleanup * adjust test * style * Proposal that should work * Remove needless code * Fix test Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -55,7 +55,10 @@ class BasicEnum(Enum):
|
||||
|
||||
@dataclass
|
||||
class EnumExample:
|
||||
foo: BasicEnum = BasicEnum.toto
|
||||
foo: BasicEnum = "toto"
|
||||
|
||||
def __post_init__(self):
|
||||
self.foo = BasicEnum(self.foo)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -133,14 +136,18 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
parser = HfArgumentParser(EnumExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", default=BasicEnum.toto, choices=list(BasicEnum), type=BasicEnum)
|
||||
expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args.foo, BasicEnum.toto)
|
||||
self.assertEqual(args.foo, "toto")
|
||||
enum_ex = parser.parse_args_into_dataclasses([])[0]
|
||||
self.assertEqual(enum_ex.foo, BasicEnum.toto)
|
||||
|
||||
args = parser.parse_args(["--foo", "titi"])
|
||||
self.assertEqual(args.foo, BasicEnum.titi)
|
||||
self.assertEqual(args.foo, "titi")
|
||||
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
|
||||
self.assertEqual(enum_ex.foo, BasicEnum.titi)
|
||||
|
||||
def test_with_list(self):
|
||||
parser = HfArgumentParser(ListExample)
|
||||
|
||||
Reference in New Issue
Block a user