From a97a73e0ee83a113cfd194e3aec1826e6cb054b5 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 30 Sep 2020 12:12:03 -0400 Subject: [PATCH] Small QOL improvements to TrainingArguments (#7475) * Small QOL improvements to TrainingArguments * With the self. --- src/transformers/training_args.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index a1f0335646..14ad9d3a3f 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -49,8 +49,9 @@ class TrainingArguments: :obj:`output_dir` points to a checkpoint directory. do_train (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to run training or not. - do_eval (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to run evaluation on the dev set or not. + do_eval (:obj:`bool`, `optional`): + Whether to run evaluation on the dev set or not. Will default to :obj:`evaluation_strategy` different from + :obj:`"no"`. do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to run predictions on the test set or not. evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`): @@ -183,7 +184,7 @@ class TrainingArguments: ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) - do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) evaluate_during_training: bool = field( default=None, @@ -333,7 +334,8 @@ class TrainingArguments: ) else: self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy) - + if self.do_eval is None: + self.do_eval = self.evaluation_strategy != EvaluationStrategy.NO if self.eval_steps is None: self.eval_steps = self.logging_steps @@ -341,6 +343,8 @@ class TrainingArguments: self.metric_for_best_model = "loss" if self.greater_is_better is None and self.metric_for_best_model is not None: self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] + if self.run_name is None: + self.run_name = self.output_dir @property def train_batch_size(self) -> int: