From 135703816441bdba59dd5154e293a5946ad538d5 Mon Sep 17 00:00:00 2001 From: Eran Hirsch Date: Wed, 22 Jun 2022 15:11:39 +0300 Subject: [PATCH] Add logits_processor parameter, used by `generate`, to `Seq2SeqTrainer` methods `evaluate` and `predict` (#17805) * Add logits_processor parameter, used by `generate`, to `Seq2SeqTrainer` methods `evaluate` and `predict` * Add all generate parameters to `Seq2SeqTrainer`, and also to `QuestionAnsweringSeq2SeqTrainer` which overrides it * Remove `self._num_beams` from trainer classes * - Run fixup - Fix "Constraint" not exposed - Fix synced_gpus to actually read from param * Use kwargs * Copy kwargs before making changes to it * Fix style issues unused imports --- .../question-answering/trainer_seq2seq_qa.py | 19 +++++-- src/transformers/trainer_seq2seq.py | 50 ++++++++++++++----- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py index ac260dadbc..6a5f6da941 100644 --- a/examples/pytorch/question-answering/trainer_seq2seq_qa.py +++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py @@ -41,11 +41,16 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): eval_examples=None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval", - max_length: Optional[int] = None, - num_beams: Optional[int] = None, + **gen_kwargs, ) -> Dict[str, float]: - self._max_length = max_length if max_length is not None else self.args.generation_max_length - self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams + gen_kwargs = gen_kwargs.copy() + gen_kwargs["max_length"] = ( + gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length + ) + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams + ) + self._gen_kwargs = gen_kwargs eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset eval_dataloader = self.get_eval_dataloader(eval_dataset) @@ -87,7 +92,11 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) return metrics - def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"): + def predict( + self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test", **gen_kwargs + ): + self._gen_kwargs = gen_kwargs.copy() + predict_dataloader = self.get_test_dataloader(predict_dataset) # Temporarily disable metric computation, we will do it in the loop here. diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index 7a290fe149..6dcb387ce7 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -33,8 +33,7 @@ class Seq2SeqTrainer(Trainer): eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval", - max_length: Optional[int] = None, - num_beams: Optional[int] = None, + **gen_kwargs ) -> Dict[str, float]: """ Run evaluation and returns metrics. @@ -60,13 +59,23 @@ class Seq2SeqTrainer(Trainer): num_beams (`int`, *optional*): Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search. + gen_kwargs: + Additional `generate` specific kwargs. Returns: A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The dictionary also contains the epoch number which comes from the training state. """ - self._max_length = max_length if max_length is not None else self.args.generation_max_length - self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams + + gen_kwargs = gen_kwargs.copy() + gen_kwargs["max_length"] = ( + gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length + ) + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams + ) + self._gen_kwargs = gen_kwargs + return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) def predict( @@ -74,8 +83,7 @@ class Seq2SeqTrainer(Trainer): test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test", - max_length: Optional[int] = None, - num_beams: Optional[int] = None, + **gen_kwargs ) -> PredictionOutput: """ Run prediction and returns predictions and potential metrics. @@ -98,6 +106,8 @@ class Seq2SeqTrainer(Trainer): num_beams (`int`, *optional*): Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search. + gen_kwargs: + Additional `generate` specific kwargs. @@ -114,8 +124,16 @@ class Seq2SeqTrainer(Trainer): - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained labels). """ - self._max_length = max_length if max_length is not None else self.args.generation_max_length - self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams + + gen_kwargs = gen_kwargs.copy() + gen_kwargs["max_length"] = ( + gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length + ) + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams + ) + self._gen_kwargs = gen_kwargs + return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) def prediction_step( @@ -155,11 +173,17 @@ class Seq2SeqTrainer(Trainer): inputs = self._prepare_inputs(inputs) # XXX: adapt synced_gpus for fairscale as well - gen_kwargs = { - "max_length": self._max_length if self._max_length is not None else self.model.config.max_length, - "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, - "synced_gpus": True if is_deepspeed_zero3_enabled() else False, - } + gen_kwargs = self._gen_kwargs.copy() + gen_kwargs["max_length"] = ( + gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length + ) + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams + ) + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + gen_kwargs["synced_gpus"] = ( + gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus + ) if "attention_mask" in inputs: gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)