Add logits_processor parameter, used by generate, to Seq2SeqTrainer methods evaluate and predict (#17805)
* Add logits_processor parameter, used by `generate`, to `Seq2SeqTrainer` methods `evaluate` and `predict` * Add all generate parameters to `Seq2SeqTrainer`, and also to `QuestionAnsweringSeq2SeqTrainer` which overrides it * Remove `self._num_beams` from trainer classes * - Run fixup - Fix "Constraint" not exposed - Fix synced_gpus to actually read from param * Use kwargs * Copy kwargs before making changes to it * Fix style issues unused imports
This commit is contained in:
@@ -41,11 +41,16 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
eval_examples=None,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
max_length: Optional[int] = None,
|
||||
num_beams: Optional[int] = None,
|
||||
**gen_kwargs,
|
||||
) -> Dict[str, float]:
|
||||
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
||||
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
||||
gen_kwargs = gen_kwargs.copy()
|
||||
gen_kwargs["max_length"] = (
|
||||
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
|
||||
)
|
||||
gen_kwargs["num_beams"] = (
|
||||
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
||||
)
|
||||
self._gen_kwargs = gen_kwargs
|
||||
|
||||
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
@@ -87,7 +92,11 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
return metrics
|
||||
|
||||
def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
|
||||
def predict(
|
||||
self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test", **gen_kwargs
|
||||
):
|
||||
self._gen_kwargs = gen_kwargs.copy()
|
||||
|
||||
predict_dataloader = self.get_test_dataloader(predict_dataset)
|
||||
|
||||
# Temporarily disable metric computation, we will do it in the loop here.
|
||||
|
||||
Reference in New Issue
Block a user