[traner] fix --lr_scheduler_type choices (#9800)
* fix --lr_scheduler_type choices * rewrite to fix for all enum-based cl args * cleanup * adjust test * style * Proposal that should work * Remove needless code * Fix test Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -94,8 +94,8 @@ class HfArgumentParser(ArgumentParser):
|
||||
field.type = prim_type
|
||||
|
||||
if isinstance(field.type, type) and issubclass(field.type, Enum):
|
||||
kwargs["choices"] = list(field.type)
|
||||
kwargs["type"] = field.type
|
||||
kwargs["choices"] = [x.value for x in field.type]
|
||||
kwargs["type"] = type(kwargs["choices"][0])
|
||||
if field.default is not dataclasses.MISSING:
|
||||
kwargs["default"] = field.default
|
||||
elif field.type is bool or field.type is Optional[bool]:
|
||||
@@ -198,7 +198,7 @@ class HfArgumentParser(ArgumentParser):
|
||||
data = json.loads(Path(json_file).read_text())
|
||||
outputs = []
|
||||
for dtype in self.dataclass_types:
|
||||
keys = {f.name for f in dataclasses.fields(dtype)}
|
||||
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
||||
inputs = {k: v for k, v in data.items() if k in keys}
|
||||
obj = dtype(**inputs)
|
||||
outputs.append(obj)
|
||||
@@ -211,7 +211,7 @@ class HfArgumentParser(ArgumentParser):
|
||||
"""
|
||||
outputs = []
|
||||
for dtype in self.dataclass_types:
|
||||
keys = {f.name for f in dataclasses.fields(dtype)}
|
||||
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
||||
inputs = {k: v for k, v in args.items() if k in keys}
|
||||
obj = dtype(**inputs)
|
||||
outputs.append(obj)
|
||||
|
||||
Reference in New Issue
Block a user