From bb9559a7f98fc18747fd957c3bd2a6e4c1111e45 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 14 Oct 2020 12:05:02 -0400 Subject: [PATCH] Don't use `store_xxx` on optional bools (#7786) * Don't use `store_xxx` on optional bools * Refine test * Refine test --- examples/test_xla_examples.py | 2 +- examples/text-classification/README.md | 2 +- src/transformers/hf_argparser.py | 3 ++- src/transformers/training_args.py | 2 +- valohai.yaml | 8 ++++---- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/test_xla_examples.py b/examples/test_xla_examples.py index 444884ddd8..c5539eee1f 100644 --- a/examples/test_xla_examples.py +++ b/examples/test_xla_examples.py @@ -59,7 +59,7 @@ class TorchXLAExamplesTests(unittest.TestCase): --model_name_or_path=bert-base-cased --per_device_train_batch_size=64 --per_device_eval_batch_size=64 - --evaluate_during_training + --evaluation_strategy steps --overwrite_cache """.split() with patch.object(sys, "argv", testargs): diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md index ce412751c1..e7d51fc7d3 100644 --- a/examples/text-classification/README.md +++ b/examples/text-classification/README.md @@ -43,7 +43,7 @@ python run_tf_text_classification.py \ --do_eval \ --do_predict \ --logging_steps 10 \ - --evaluate_during_training \ + --evaluation_strategy steps \ --save_steps 10 \ --overwrite_output_dir \ --max_seq_length 128 diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index f1b4f31526..0b08be85e4 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -65,7 +65,8 @@ class HfArgumentParser(ArgumentParser): if field.default is not dataclasses.MISSING: kwargs["default"] = field.default elif field.type is bool or field.type is Optional[bool]: - kwargs["action"] = "store_false" if field.default is True else "store_true" + if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING): + kwargs["action"] = "store_false" if field.default is True else "store_true" if field.default is True: field_name = f"--no-{field.name}" kwargs["dest"] = field.name diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 6cfdc15f07..17ea24ba2e 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -191,7 +191,7 @@ class TrainingArguments: do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) evaluate_during_training: bool = field( - default=None, + default=False, metadata={"help": "Run evaluation during training at each logging step."}, ) evaluation_strategy: EvaluationStrategy = field( diff --git a/valohai.yaml b/valohai.yaml index 753549ecde..14441e27d0 100644 --- a/valohai.yaml +++ b/valohai.yaml @@ -85,7 +85,7 @@ pass-as: --output_dir={v} type: string default: /valohai/outputs - - name: evaluate_during_training - description: Run evaluation during training at each logging step. - type: flag - default: true + - name: evaluation_strategy + description: The evaluation strategy to use. + type: string + default: steps