Don't use store_xxx on optional bools (#7786)
* Don't use `store_xxx` on optional bools * Refine test * Refine test
This commit is contained in:
@@ -59,7 +59,7 @@ class TorchXLAExamplesTests(unittest.TestCase):
|
|||||||
--model_name_or_path=bert-base-cased
|
--model_name_or_path=bert-base-cased
|
||||||
--per_device_train_batch_size=64
|
--per_device_train_batch_size=64
|
||||||
--per_device_eval_batch_size=64
|
--per_device_eval_batch_size=64
|
||||||
--evaluate_during_training
|
--evaluation_strategy steps
|
||||||
--overwrite_cache
|
--overwrite_cache
|
||||||
""".split()
|
""".split()
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ python run_tf_text_classification.py \
|
|||||||
--do_eval \
|
--do_eval \
|
||||||
--do_predict \
|
--do_predict \
|
||||||
--logging_steps 10 \
|
--logging_steps 10 \
|
||||||
--evaluate_during_training \
|
--evaluation_strategy steps \
|
||||||
--save_steps 10 \
|
--save_steps 10 \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--max_seq_length 128
|
--max_seq_length 128
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
if field.default is not dataclasses.MISSING:
|
if field.default is not dataclasses.MISSING:
|
||||||
kwargs["default"] = field.default
|
kwargs["default"] = field.default
|
||||||
elif field.type is bool or field.type is Optional[bool]:
|
elif field.type is bool or field.type is Optional[bool]:
|
||||||
|
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
||||||
kwargs["action"] = "store_false" if field.default is True else "store_true"
|
kwargs["action"] = "store_false" if field.default is True else "store_true"
|
||||||
if field.default is True:
|
if field.default is True:
|
||||||
field_name = f"--no-{field.name}"
|
field_name = f"--no-{field.name}"
|
||||||
|
|||||||
@@ -191,7 +191,7 @@ class TrainingArguments:
|
|||||||
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
|
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
||||||
evaluate_during_training: bool = field(
|
evaluate_during_training: bool = field(
|
||||||
default=None,
|
default=False,
|
||||||
metadata={"help": "Run evaluation during training at each logging step."},
|
metadata={"help": "Run evaluation during training at each logging step."},
|
||||||
)
|
)
|
||||||
evaluation_strategy: EvaluationStrategy = field(
|
evaluation_strategy: EvaluationStrategy = field(
|
||||||
|
|||||||
@@ -85,7 +85,7 @@
|
|||||||
pass-as: --output_dir={v}
|
pass-as: --output_dir={v}
|
||||||
type: string
|
type: string
|
||||||
default: /valohai/outputs
|
default: /valohai/outputs
|
||||||
- name: evaluate_during_training
|
- name: evaluation_strategy
|
||||||
description: Run evaluation during training at each logging step.
|
description: The evaluation strategy to use.
|
||||||
type: flag
|
type: string
|
||||||
default: true
|
default: steps
|
||||||
|
|||||||
Reference in New Issue
Block a user