fix type annotation for debug arg (#24033)
* fix type annotation for debug arg * fix TypeErorr
This commit is contained in:
@@ -873,7 +873,7 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
debug: str = field(
|
debug: Union[str, List[DebugOption]] = field(
|
||||||
default="",
|
default="",
|
||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": (
|
||||||
@@ -1563,8 +1563,11 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
self.debug += " tpu_metrics_debug"
|
self.debug += " tpu_metrics_debug"
|
||||||
self.tpu_metrics_debug = False
|
self.tpu_metrics_debug = False
|
||||||
|
|
||||||
if isinstance(self.debug, str):
|
if isinstance(self.debug, str):
|
||||||
self.debug = [DebugOption(s) for s in self.debug.split()]
|
self.debug = [DebugOption(s) for s in self.debug.split()]
|
||||||
|
elif self.debug is None:
|
||||||
|
self.debug = []
|
||||||
|
|
||||||
self.deepspeed_plugin = None
|
self.deepspeed_plugin = None
|
||||||
if self.deepspeed:
|
if self.deepspeed:
|
||||||
|
|||||||
Reference in New Issue
Block a user