[finetune_trainer] enhancements and fixes (#9042)

* trainer and finetune_trainer enhancements and fixes

* add fallback default

* move the fixing of incorrect keys back into finetune trainer

* s/eval/val/ to match the split

* trainer can now use a different prefix than eval_ for metrics

* document new arg

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* use 'eval' as the default for metric_key_prefix

* complete adjust var names + disambiguate

* fix logger

* add clarifying comment

* add clarifying comment

* style

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/trainer.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* complete removal of optional for metric_key_prefix

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Stas Bekman
2020-12-14 17:45:33 -08:00
committed by GitHub
parent 251eb70c97
commit c19d04623e
3 changed files with 97 additions and 29 deletions

View File

@@ -1243,7 +1243,10 @@ class Trainer:
shutil.rmtree(checkpoint)
def evaluate(
self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
self,
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
"""
Run evaluation and return metrics.
@@ -1261,6 +1264,9 @@ class Trainer:
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
"eval_bleu" if the prefix is "eval" (default).
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@@ -1278,6 +1284,7 @@ class Trainer:
# self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
self.log(output.metrics)
@@ -1289,7 +1296,9 @@ class Trainer:
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
return output.metrics
def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
def predict(
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval"
) -> PredictionOutput:
"""
Run prediction and return predictions and potential metrics.
@@ -1303,6 +1312,9 @@ class Trainer:
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
"eval_bleu" if the prefix is "eval" (default).
.. note::
@@ -1322,7 +1334,9 @@ class Trainer:
test_dataloader = self.get_test_dataloader(test_dataset)
return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
return self.prediction_loop(
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
def prediction_loop(
self,
@@ -1330,6 +1344,7 @@ class Trainer:
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> PredictionOutput:
"""
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
@@ -1421,12 +1436,12 @@ class Trainer:
metrics = {}
if eval_loss is not None:
metrics["eval_loss"] = eval_loss.mean().item()
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
# Prefix all keys with eval_
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith("eval_"):
metrics[f"eval_{key}"] = metrics.pop(key)
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)