Add generate kwargs to Seq2SeqTrainingArguments (#13339)

* Add generate kwargs to Seq2SeqTrainingArguments

* typo

* Address review comments + doc

* Style
This commit is contained in:
Sylvain Gugger
2021-08-31 08:42:00 -04:00
committed by GitHub
parent 702f4a49cd
commit c76de1053e
4 changed files with 41 additions and 23 deletions

View File

@@ -556,12 +556,15 @@ def main():
# Evaluation # Evaluation
results = {} results = {}
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
if training_args.do_eval: if training_args.do_eval:
logger.info("*** Evaluate ***") logger.info("*** Evaluate ***")
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
metrics = trainer.evaluate(
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
)
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@@ -572,10 +575,7 @@ def main():
logger.info("*** Predict ***") logger.info("*** Predict ***")
predict_results = trainer.predict( predict_results = trainer.predict(
predict_dataset, predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
metric_key_prefix="predict",
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
) )
metrics = predict_results.metrics metrics = predict_results.metrics
max_predict_samples = ( max_predict_samples = (

View File

@@ -549,12 +549,16 @@ def main():
# Evaluation # Evaluation
results = {} results = {}
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
if training_args.do_eval: if training_args.do_eval:
logger.info("*** Evaluate ***") logger.info("*** Evaluate ***")
metrics = trainer.evaluate( metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
)
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@@ -565,10 +569,7 @@ def main():
logger.info("*** Predict ***") logger.info("*** Predict ***")
predict_results = trainer.predict( predict_results = trainer.predict(
predict_dataset, predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
metric_key_prefix="predict",
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
) )
metrics = predict_results.metrics metrics = predict_results.metrics
max_predict_samples = ( max_predict_samples = (

View File

@@ -70,10 +70,8 @@ class Seq2SeqTrainer(Trainer):
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state. dictionary also contains the epoch number which comes from the training state.
""" """
if max_length is not None or not hasattr(self, "_max_length"): self._max_length = max_length if max_length is not None else self.args.generation_max_length
self._max_length = max_length self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def predict( def predict(
@@ -119,10 +117,8 @@ class Seq2SeqTrainer(Trainer):
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
contained labels). contained labels).
""" """
if max_length is not None or not hasattr(self, "_max_length"): self._max_length = max_length if max_length is not None else self.args.generation_max_length
self._max_length = max_length self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step( def prediction_step(

View File

@@ -14,6 +14,7 @@
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .training_args import TrainingArguments from .training_args import TrainingArguments
@@ -34,9 +35,29 @@ class Seq2SeqTrainingArguments(TrainingArguments):
the training set. the training set.
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`): predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use generate to calculate generative metrics (ROUGE, BLEU). Whether to use generate to calculate generative metrics (ROUGE, BLEU).
generation_max_length (:obj:`int`, `optional`):
The :obj:`max_length` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to
the :obj:`max_length` value of the model configuration.
generation_num_beams (:obj:`int`, `optional`):
The :obj:`num_beams` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to the
:obj:`num_beams` value of the model configuration.
""" """
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."}) sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
predict_with_generate: bool = field( predict_with_generate: bool = field(
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
) )
generation_max_length: Optional[int] = field(
default=None,
metadata={
"help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
"to the `max_length` value of the model configuration."
},
)
generation_num_beams: Optional[int] = field(
default=None,
metadata={
"help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
"to the `num_beams` value of the model configuration."
},
)