Add logits_processor parameter, used by generate, to Seq2SeqTrainer methods evaluate and predict (#17805)
* Add `logits_processor` parameter, used by `generate`, to `Seq2SeqTrainer` methods `evaluate` and `predict`
* Add all generate parameters to `Seq2SeqTrainer`, and also to `QuestionAnsweringSeq2SeqTrainer` which overrides it
* Remove `self._num_beams` from trainer classes
* Run fixup
* Fix "Constraint" not exposed
* Fix `synced_gpus` to actually read from param
* Use kwargs
* Copy kwargs before making changes to it
* Fix style issues, unused imports
This commit is contained in:
@@ -41,11 +41,16 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
eval_examples=None,
|
eval_examples=None,
|
||||||
ignore_keys: Optional[List[str]] = None,
|
ignore_keys: Optional[List[str]] = None,
|
||||||
metric_key_prefix: str = "eval",
|
metric_key_prefix: str = "eval",
|
||||||
max_length: Optional[int] = None,
|
**gen_kwargs,
|
||||||
num_beams: Optional[int] = None,
|
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
gen_kwargs = gen_kwargs.copy()
|
||||||
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
gen_kwargs["max_length"] = (
|
||||||
|
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
|
||||||
|
)
|
||||||
|
gen_kwargs["num_beams"] = (
|
||||||
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
||||||
|
)
|
||||||
|
self._gen_kwargs = gen_kwargs
|
||||||
|
|
||||||
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
||||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||||
@@ -87,7 +92,11 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
|
def predict(
|
||||||
|
self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test", **gen_kwargs
|
||||||
|
):
|
||||||
|
self._gen_kwargs = gen_kwargs.copy()
|
||||||
|
|
||||||
predict_dataloader = self.get_test_dataloader(predict_dataset)
|
predict_dataloader = self.get_test_dataloader(predict_dataset)
|
||||||
|
|
||||||
# Temporarily disable metric computation, we will do it in the loop here.
|
# Temporarily disable metric computation, we will do it in the loop here.
|
||||||
|
|||||||
@@ -33,8 +33,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
eval_dataset: Optional[Dataset] = None,
|
eval_dataset: Optional[Dataset] = None,
|
||||||
ignore_keys: Optional[List[str]] = None,
|
ignore_keys: Optional[List[str]] = None,
|
||||||
metric_key_prefix: str = "eval",
|
metric_key_prefix: str = "eval",
|
||||||
max_length: Optional[int] = None,
|
**gen_kwargs
|
||||||
num_beams: Optional[int] = None,
|
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Run evaluation and returns metrics.
|
Run evaluation and returns metrics.
|
||||||
@@ -60,13 +59,23 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
num_beams (`int`, *optional*):
|
num_beams (`int`, *optional*):
|
||||||
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
|
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
|
||||||
beam search.
|
beam search.
|
||||||
|
gen_kwargs:
|
||||||
|
Additional `generate` specific kwargs.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
||||||
dictionary also contains the epoch number which comes from the training state.
|
dictionary also contains the epoch number which comes from the training state.
|
||||||
"""
|
"""
|
||||||
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
|
||||||
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
gen_kwargs = gen_kwargs.copy()
|
||||||
|
gen_kwargs["max_length"] = (
|
||||||
|
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
|
||||||
|
)
|
||||||
|
gen_kwargs["num_beams"] = (
|
||||||
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
||||||
|
)
|
||||||
|
self._gen_kwargs = gen_kwargs
|
||||||
|
|
||||||
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
@@ -74,8 +83,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
test_dataset: Dataset,
|
test_dataset: Dataset,
|
||||||
ignore_keys: Optional[List[str]] = None,
|
ignore_keys: Optional[List[str]] = None,
|
||||||
metric_key_prefix: str = "test",
|
metric_key_prefix: str = "test",
|
||||||
max_length: Optional[int] = None,
|
**gen_kwargs
|
||||||
num_beams: Optional[int] = None,
|
|
||||||
) -> PredictionOutput:
|
) -> PredictionOutput:
|
||||||
"""
|
"""
|
||||||
Run prediction and returns predictions and potential metrics.
|
Run prediction and returns predictions and potential metrics.
|
||||||
@@ -98,6 +106,8 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
num_beams (`int`, *optional*):
|
num_beams (`int`, *optional*):
|
||||||
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
|
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
|
||||||
beam search.
|
beam search.
|
||||||
|
gen_kwargs:
|
||||||
|
Additional `generate` specific kwargs.
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
@@ -114,8 +124,16 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
|
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
|
||||||
labels).
|
labels).
|
||||||
"""
|
"""
|
||||||
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
|
||||||
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
gen_kwargs = gen_kwargs.copy()
|
||||||
|
gen_kwargs["max_length"] = (
|
||||||
|
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
|
||||||
|
)
|
||||||
|
gen_kwargs["num_beams"] = (
|
||||||
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
||||||
|
)
|
||||||
|
self._gen_kwargs = gen_kwargs
|
||||||
|
|
||||||
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
|
|
||||||
def prediction_step(
|
def prediction_step(
|
||||||
@@ -155,11 +173,17 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
# XXX: adapt synced_gpus for fairscale as well
|
# XXX: adapt synced_gpus for fairscale as well
|
||||||
gen_kwargs = {
|
gen_kwargs = self._gen_kwargs.copy()
|
||||||
"max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
|
gen_kwargs["max_length"] = (
|
||||||
"num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
|
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length
|
||||||
"synced_gpus": True if is_deepspeed_zero3_enabled() else False,
|
)
|
||||||
}
|
gen_kwargs["num_beams"] = (
|
||||||
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
|
||||||
|
)
|
||||||
|
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
|
||||||
|
gen_kwargs["synced_gpus"] = (
|
||||||
|
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
|
||||||
|
)
|
||||||
|
|
||||||
if "attention_mask" in inputs:
|
if "attention_mask" in inputs:
|
||||||
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
|
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
|
||||||
|
|||||||
Reference in New Issue
Block a user