Add generate kwargs to Seq2SeqTrainingArguments (#13339)
* Add generate kwargs to Seq2SeqTrainingArguments
* typo
* Address review comments + doc
* Style
This commit is contained in:
@@ -556,12 +556,15 @@ def main():
|
|||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
|
max_length = (
|
||||||
|
training_args.generation_max_length
|
||||||
|
if training_args.generation_max_length is not None
|
||||||
|
else data_args.val_max_target_length
|
||||||
|
)
|
||||||
|
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
|
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
|
||||||
metrics = trainer.evaluate(
|
|
||||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
|
||||||
)
|
|
||||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
@@ -572,10 +575,7 @@ def main():
|
|||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
|
|
||||||
predict_results = trainer.predict(
|
predict_results = trainer.predict(
|
||||||
predict_dataset,
|
predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
|
||||||
metric_key_prefix="predict",
|
|
||||||
max_length=data_args.val_max_target_length,
|
|
||||||
num_beams=data_args.num_beams,
|
|
||||||
)
|
)
|
||||||
metrics = predict_results.metrics
|
metrics = predict_results.metrics
|
||||||
max_predict_samples = (
|
max_predict_samples = (
|
||||||
|
|||||||
@@ -549,12 +549,16 @@ def main():
|
|||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
|
max_length = (
|
||||||
|
training_args.generation_max_length
|
||||||
|
if training_args.generation_max_length is not None
|
||||||
|
else data_args.val_max_target_length
|
||||||
|
)
|
||||||
|
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
metrics = trainer.evaluate(
|
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
|
||||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
|
||||||
)
|
|
||||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
@@ -565,10 +569,7 @@ def main():
|
|||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
|
|
||||||
predict_results = trainer.predict(
|
predict_results = trainer.predict(
|
||||||
predict_dataset,
|
predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
|
||||||
metric_key_prefix="predict",
|
|
||||||
max_length=data_args.val_max_target_length,
|
|
||||||
num_beams=data_args.num_beams,
|
|
||||||
)
|
)
|
||||||
metrics = predict_results.metrics
|
metrics = predict_results.metrics
|
||||||
max_predict_samples = (
|
max_predict_samples = (
|
||||||
|
|||||||
@@ -70,10 +70,8 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
||||||
dictionary also contains the epoch number which comes from the training state.
|
dictionary also contains the epoch number which comes from the training state.
|
||||||
"""
|
"""
|
||||||
if max_length is not None or not hasattr(self, "_max_length"):
|
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
||||||
self._max_length = max_length
|
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
||||||
if num_beams is not None or not hasattr(self, "_num_beams"):
|
|
||||||
self._num_beams = num_beams
|
|
||||||
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
@@ -119,10 +117,8 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
|
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
|
||||||
contained labels).
|
contained labels).
|
||||||
"""
|
"""
|
||||||
if max_length is not None or not hasattr(self, "_max_length"):
|
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
||||||
self._max_length = max_length
|
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
||||||
if num_beams is not None or not hasattr(self, "_num_beams"):
|
|
||||||
self._num_beams = num_beams
|
|
||||||
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
|
|
||||||
def prediction_step(
|
def prediction_step(
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
@@ -34,9 +35,29 @@ class Seq2SeqTrainingArguments(TrainingArguments):
|
|||||||
the training set.
|
the training set.
|
||||||
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
|
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
|
||||||
|
generation_max_length (:obj:`int`, `optional`):
|
||||||
|
The :obj:`max_length` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to
|
||||||
|
the :obj:`max_length` value of the model configuration.
|
||||||
|
generation_num_beams (:obj:`int`, `optional`):
|
||||||
|
The :obj:`num_beams` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to the
|
||||||
|
:obj:`num_beams` value of the model configuration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
|
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
|
||||||
predict_with_generate: bool = field(
|
predict_with_generate: bool = field(
|
||||||
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
||||||
)
|
)
|
||||||
|
generation_max_length: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
|
||||||
|
"to the `max_length` value of the model configuration."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
generation_num_beams: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
|
||||||
|
"to the `num_beams` value of the model configuration."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user