From 1cd16512dc8060aa8c2419664f9cb83813ade4d5 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 5 Feb 2021 23:21:57 +0530 Subject: [PATCH] [examples/seq2seq] support label smoothing (#9844) * add prepare_decoder_input_ids_from_labels in s2s models * support lbl smoothing and enc/emb freezing * fix freezing * use pad_token_id from config * remove embed freezing and add warning * prepare decoder_input_ids inside DataCollatorForSeq2Seq --- examples/seq2seq/run_seq2seq.py | 7 +++++++ src/transformers/data/data_collator.py | 16 +++++++++++++++- src/transformers/models/bart/modeling_bart.py | 3 +++ src/transformers/models/fsmt/modeling_fsmt.py | 3 +++ src/transformers/models/led/modeling_led.py | 3 +++ .../models/marian/modeling_marian.py | 3 +++ src/transformers/models/mbart/modeling_mbart.py | 3 +++ .../models/pegasus/modeling_pegasus.py | 3 +++ .../models/prophetnet/modeling_prophetnet.py | 3 +++ src/transformers/models/t5/modeling_t5.py | 3 +++ 10 files changed, 46 insertions(+), 1 deletion(-) diff --git a/examples/seq2seq/run_seq2seq.py b/examples/seq2seq/run_seq2seq.py index 0bbc0de2ec..7b423f6d0d 100644 --- a/examples/seq2seq/run_seq2seq.py +++ b/examples/seq2seq/run_seq2seq.py @@ -384,6 +384,12 @@ def main(): max_target_length = data_args.max_target_length padding = "max_length" if data_args.pad_to_max_length else False + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.warn( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for " + f"`{model.__class__.__name__}`. 
This will lead to loss being calculated twice and will take up more memory" + ) + def preprocess_function(examples): if data_args.task.startswith("translation"): inputs = [ex[source_lang] for ex in examples["translation"]] @@ -440,6 +446,7 @@ def main(): else: data_collator = DataCollatorForSeq2Seq( tokenizer, + model=model, label_pad_token_id=label_pad_token_id, pad_to_multiple_of=8 if training_args.fp16 else None, ) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index d585b419e6..530d28306c 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union import torch from torch.nn.utils.rnn import pad_sequence +from ..modeling_utils import PreTrainedModel from ..tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase @@ -232,6 +233,11 @@ class DataCollatorForSeq2Seq: Args: tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): The tokenizer used for encoding the data. + model (:class:`~transformers.PreTrainedModel`): + The model that is being trained. If set and it has the `prepare_decoder_input_ids_from_labels` method, use it to + prepare the `decoder_input_ids`. + + This is useful when using `label_smoothing` to avoid calculating loss twice. 
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: @@ -254,6 +260,7 @@ class DataCollatorForSeq2Seq: """ tokenizer: PreTrainedTokenizerBase + model: Optional[PreTrainedModel] = None padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None pad_to_multiple_of: Optional[int] = None @@ -272,7 +279,7 @@ class DataCollatorForSeq2Seq: feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] ) - return self.tokenizer.pad( + features = self.tokenizer.pad( features, padding=self.padding, max_length=self.max_length, @@ -280,6 +287,13 @@ class DataCollatorForSeq2Seq: return_tensors="pt", ) + # prepare decoder_input_ids + if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"): + decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"]) + features["decoder_input_ids"] = decoder_input_ids + + return features + @dataclass class DataCollatorForLanguageModeling: diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 64c968bacc..a595e25d72 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1341,6 +1341,9 @@ class BartForConditionalGeneration(BartPretrainedModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + def adjust_logits_during_generation(self, logits, cur_len, max_length): if cur_len == 1 and self.config.force_bos_token_to_be_generated: self._force_token_id_to_be_generated(logits, self.config.bos_token_id) diff --git 
a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 957ba4e84e..c2fcf9c8eb 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -1207,6 +1207,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id) + def adjust_logits_during_generation(self, logits, cur_len, max_length): if cur_len == max_length - 1 and self.config.eos_token_id is not None: self._force_token_ids_generation(logits, self.config.eos_token_id) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 64efdf619a..16fd9a58a1 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2406,6 +2406,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 070387cf98..24011e31f5 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1320,6 +1320,9 @@ class MarianMTModel(MarianPreTrainedModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + def 
adjust_logits_during_generation(self, logits, cur_len, max_length): logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token. if cur_len == max_length - 1 and self.config.eos_token_id is not None: diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index a6c03651fe..2aef23d1c0 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1341,6 +1341,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id) + def adjust_logits_during_generation(self, logits, cur_len, max_length): if cur_len == max_length - 1 and self.config.eos_token_id is not None: self._force_token_id_to_be_generated(logits, self.config.eos_token_id) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 1508b80d46..36f7f13ca0 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1324,6 +1324,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + def adjust_logits_during_generation(self, logits, cur_len, max_length): if cur_len == max_length - 1 and self.config.eos_token_id is not None: self._force_token_id_to_be_generated(logits, self.config.eos_token_id) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index f8cbb11f6f..704e86059c 100644 --- 
a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1852,6 +1852,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): "use_cache": use_cache, } + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + @staticmethod def _reorder_cache(past, beam_idx): # this function reorders the cache for beam search diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index d0f5e5d1a7..6ed8037f17 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1608,6 +1608,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): "use_cache": use_cache, } + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + def _reorder_cache(self, past, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder