From d1d15d6f2de9e2cde48ff3ea2072add3311ce2ac Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 27 Jul 2020 19:40:43 +0530 Subject: [PATCH] [examples (seq2seq)] fix preparing decoder_input_ids for T5 (#5994) --- examples/seq2seq/finetune.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index a0014b9835..1866042cbb 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -14,7 +14,7 @@ import torch from torch.utils.data import DataLoader from lightning_base import BaseTransformer, add_generic_args, generic_train -from transformers import MBartTokenizer, get_linear_schedule_with_warmup +from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup try: @@ -131,8 +131,14 @@ class SummarizationModule(BaseTransformer): def _step(self, batch: dict) -> Tuple: pad_token_id = self.tokenizer.pad_token_id source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] - decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line? - lm_labels = target_ids[:, 1:].clone() # why clone? + + if isinstance(self.model, T5ForConditionalGeneration): + decoder_input_ids = self.model._shift_right(target_ids) + lm_labels = target_ids + else: + decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line? + lm_labels = target_ids[:, 1:].clone() # why clone? + outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False) if self.hparams.label_smoothing == 0: