[traner] fix --lr_scheduler_type choices (#9800)
* fix --lr_scheduler_type choices * rewrite to fix for all enum-based cl args * cleanup * adjust test * style * Proposal that should work * Remove needless code * Fix test Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -94,8 +94,8 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
field.type = prim_type
|
field.type = prim_type
|
||||||
|
|
||||||
if isinstance(field.type, type) and issubclass(field.type, Enum):
|
if isinstance(field.type, type) and issubclass(field.type, Enum):
|
||||||
kwargs["choices"] = list(field.type)
|
kwargs["choices"] = [x.value for x in field.type]
|
||||||
kwargs["type"] = field.type
|
kwargs["type"] = type(kwargs["choices"][0])
|
||||||
if field.default is not dataclasses.MISSING:
|
if field.default is not dataclasses.MISSING:
|
||||||
kwargs["default"] = field.default
|
kwargs["default"] = field.default
|
||||||
elif field.type is bool or field.type is Optional[bool]:
|
elif field.type is bool or field.type is Optional[bool]:
|
||||||
@@ -198,7 +198,7 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
data = json.loads(Path(json_file).read_text())
|
data = json.loads(Path(json_file).read_text())
|
||||||
outputs = []
|
outputs = []
|
||||||
for dtype in self.dataclass_types:
|
for dtype in self.dataclass_types:
|
||||||
keys = {f.name for f in dataclasses.fields(dtype)}
|
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
||||||
inputs = {k: v for k, v in data.items() if k in keys}
|
inputs = {k: v for k, v in data.items() if k in keys}
|
||||||
obj = dtype(**inputs)
|
obj = dtype(**inputs)
|
||||||
outputs.append(obj)
|
outputs.append(obj)
|
||||||
@@ -211,7 +211,7 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
"""
|
"""
|
||||||
outputs = []
|
outputs = []
|
||||||
for dtype in self.dataclass_types:
|
for dtype in self.dataclass_types:
|
||||||
keys = {f.name for f in dataclasses.fields(dtype)}
|
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
||||||
inputs = {k: v for k, v in args.items() if k in keys}
|
inputs = {k: v for k, v in args.items() if k in keys}
|
||||||
obj = dtype(**inputs)
|
obj = dtype(**inputs)
|
||||||
outputs.append(obj)
|
outputs.append(obj)
|
||||||
|
|||||||
@@ -55,7 +55,10 @@ class BasicEnum(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EnumExample:
|
class EnumExample:
|
||||||
foo: BasicEnum = BasicEnum.toto
|
foo: BasicEnum = "toto"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.foo = BasicEnum(self.foo)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -133,14 +136,18 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
parser = HfArgumentParser(EnumExample)
|
parser = HfArgumentParser(EnumExample)
|
||||||
|
|
||||||
expected = argparse.ArgumentParser()
|
expected = argparse.ArgumentParser()
|
||||||
expected.add_argument("--foo", default=BasicEnum.toto, choices=list(BasicEnum), type=BasicEnum)
|
expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
|
||||||
self.argparsersEqual(parser, expected)
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
args = parser.parse_args([])
|
args = parser.parse_args([])
|
||||||
self.assertEqual(args.foo, BasicEnum.toto)
|
self.assertEqual(args.foo, "toto")
|
||||||
|
enum_ex = parser.parse_args_into_dataclasses([])[0]
|
||||||
|
self.assertEqual(enum_ex.foo, BasicEnum.toto)
|
||||||
|
|
||||||
args = parser.parse_args(["--foo", "titi"])
|
args = parser.parse_args(["--foo", "titi"])
|
||||||
self.assertEqual(args.foo, BasicEnum.titi)
|
self.assertEqual(args.foo, "titi")
|
||||||
|
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
|
||||||
|
self.assertEqual(enum_ex.foo, BasicEnum.titi)
|
||||||
|
|
||||||
def test_with_list(self):
|
def test_with_list(self):
|
||||||
parser = HfArgumentParser(ListExample)
|
parser = HfArgumentParser(ListExample)
|
||||||
|
|||||||
Reference in New Issue
Block a user