[finetune_trainer] enhancements and fixes (#9042)
* trainer and finetune_trainer enhancements and fixes * add fallback default * move the fixing of incorrect keys back into finetune trainer * s/eval/val/ to match the split * trainer can now use a different prefix than eval_ for metrics * document new arg * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * use 'eval' as the default for metric_key_prefix * complete adjust var names + disambiguate * fix logger * add clarifying comment * add clarifying comment * style * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/trainer.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * complete removal of optional for metric_key_prefix * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -1243,7 +1243,10 @@ class Trainer:
|
||||
shutil.rmtree(checkpoint)
|
||||
|
||||
def evaluate(
|
||||
self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
|
||||
self,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Run evaluation and returns metrics.
|
||||
@@ -1261,6 +1264,9 @@ class Trainer:
|
||||
ignore_keys (:obj:`Lst[str]`, `optional`):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
|
||||
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
|
||||
"eval_bleu" if the prefix is "eval" (default)
|
||||
|
||||
Returns:
|
||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
||||
@@ -1278,6 +1284,7 @@ class Trainer:
|
||||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if self.compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
metric_key_prefix=metric_key_prefix,
|
||||
)
|
||||
|
||||
self.log(output.metrics)
|
||||
@@ -1289,7 +1296,9 @@ class Trainer:
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
|
||||
return output.metrics
|
||||
|
||||
def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
|
||||
def predict(
|
||||
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval"
|
||||
) -> PredictionOutput:
|
||||
"""
|
||||
Run prediction and returns predictions and potential metrics.
|
||||
|
||||
@@ -1303,6 +1312,9 @@ class Trainer:
|
||||
ignore_keys (:obj:`Lst[str]`, `optional`):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
|
||||
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
|
||||
"eval_bleu" if the prefix is "eval" (default)
|
||||
|
||||
.. note::
|
||||
|
||||
@@ -1322,7 +1334,9 @@ class Trainer:
|
||||
|
||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||
|
||||
return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
|
||||
return self.prediction_loop(
|
||||
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
||||
)
|
||||
|
||||
def prediction_loop(
|
||||
self,
|
||||
@@ -1330,6 +1344,7 @@ class Trainer:
|
||||
description: str,
|
||||
prediction_loss_only: Optional[bool] = None,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
) -> PredictionOutput:
|
||||
"""
|
||||
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
|
||||
@@ -1421,12 +1436,12 @@ class Trainer:
|
||||
metrics = {}
|
||||
|
||||
if eval_loss is not None:
|
||||
metrics["eval_loss"] = eval_loss.mean().item()
|
||||
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
|
||||
|
||||
# Prefix all keys with eval_
|
||||
# Prefix all keys with metric_key_prefix + '_'
|
||||
for key in list(metrics.keys()):
|
||||
if not key.startswith("eval_"):
|
||||
metrics[f"eval_{key}"] = metrics.pop(key)
|
||||
if not key.startswith(f"{metric_key_prefix}_"):
|
||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
|
||||
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user