From 664c7ec453593f0924cad3396709d52fea68c4f8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 26 Oct 2020 17:28:16 +0100 Subject: [PATCH] [Seq2Seq Trainer] Make sure padding is implemented for models without pad_token (#8043) * make sure padding is implemented for non-padding tokens models as well * add better error message * add better warning * remove results files * Update examples/seq2seq/seq2seq_trainer.py * remove unnecessary copy line * correct usage of labels * delete test files --- examples/seq2seq/seq2seq_trainer.py | 51 ++++++++++++++--------- examples/seq2seq/test_finetune_trainer.py | 2 + 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py index 39ca2c9cd2..1b3d52ad44 100644 --- a/examples/seq2seq/seq2seq_trainer.py +++ b/examples/seq2seq/seq2seq_trainer.py @@ -1,4 +1,3 @@ -import copy from typing import Any, Dict, Optional, Tuple, Union import torch @@ -60,6 +59,11 @@ class Seq2SeqTrainer(Trainer): self.config.pad_token_id is not None ), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing." + if self.config.pad_token_id is None and self.config.eos_token_id is not None: + logger.warn( + f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.." + ) + def create_optimizer_and_scheduler(self, num_training_steps: int): """ Setup the optimizer and the learning rate scheduler. @@ -126,22 +130,19 @@ class Seq2SeqTrainer(Trainer): else DistributedSampler(self.train_dataset) ) - def _compute_loss(self, model, inputs): - inputs = copy.deepcopy(inputs) + def _compute_loss(self, model, inputs, labels): if self.args.label_smoothing == 0: if self.data_args is not None and self.data_args.ignore_pad_token_for_loss: # force training to ignore pad token - labels = inputs.pop("labels") logits = model(**inputs, use_cache=False)[0] loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id) loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1)) else: # compute usual loss via models - loss, logits = model(**inputs, use_cache=False)[:2] + loss, logits = model(**inputs, labels=labels, use_cache=False)[:2] else: # compute label smoothed loss - labels = inputs.pop("labels") logits = model(**inputs, use_cache=False)[0] lprobs = torch.nn.functional.log_softmax(logits, dim=-1) loss, _ = label_smoothed_nll_loss( @@ -150,7 +151,8 @@ class Seq2SeqTrainer(Trainer): return loss, logits def compute_loss(self, model, inputs): - loss, _ = self._compute_loss(model, inputs) + labels = inputs.pop("labels") + loss, _ = self._compute_loss(model, inputs, labels) return loss def prediction_step( @@ -178,25 +180,27 @@ class Seq2SeqTrainer(Trainer): """ inputs = self._prepare_inputs(inputs) + gen_kwargs = { + "max_length": self.data_args.val_max_target_length + if self.data_args is not None + else self.config.max_length, + "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams, + } + if self.args.predict_with_generate and not self.args.prediction_loss_only: - gen_kwargs = { - "max_length": self.data_args.val_max_target_length - if self.data_args is not None - else self.config.max_length, - "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams, - } generated_tokens = model.generate( inputs["input_ids"], attention_mask=inputs["attention_mask"], **gen_kwargs, ) # in case the batch is shorter than max length, the output should be padded - if self.config.pad_token_id is not None: + if generated_tokens.shape[-1] < gen_kwargs["max_length"]: generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) - # compute loss on predict data + labels = inputs.pop("labels") with torch.no_grad(): - loss, logits = self._compute_loss(model, inputs) + # compute loss on predict data + loss, logits = self._compute_loss(model, inputs, labels) loss = loss.mean().detach() if self.args.prediction_loss_only: @@ -204,14 +208,21 @@ class Seq2SeqTrainer(Trainer): logits = generated_tokens if self.args.predict_with_generate else logits - labels = inputs["labels"] - if self.config.pad_token_id is not None: - labels = self._pad_tensors_to_max_len(labels, self.config.max_length) + if labels.shape[-1] < gen_kwargs["max_length"]: + labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) return (loss, logits, labels) def _pad_tensors_to_max_len(self, tensor, max_length): - padded_tensor = self.config.pad_token_id * torch.ones( + # If PAD token is not defined at least EOS token has to be defined + pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id + + if pad_token_id is None: + raise ValueError( + f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}" + ) + + padded_tensor = pad_token_id * torch.ones( (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device ) padded_tensor[:, : tensor.shape[-1]] = tensor diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 8a0cdf3aa7..f2c879d09d 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -63,7 +63,9 @@ class TestFinetuneTrainer(TestCasePlus): tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size + bert2bert.config.eos_token_id = tokenizer.sep_token_id bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id + bert2bert.config.max_length = 128 train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]") val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")