[examples/seq2seq] support label smoothing (#9844)
* add prepare_decoder_input_ids_from_labels in s2s models * support lbl smoothing and enc/emb freezing * fix freezing * use pad_token_id from config * remove embed freezing and add warning * prepare decoder_input_ids inside DataCollatorForSeq2Seq
This commit is contained in:
@@ -384,6 +384,12 @@ def main():
|
|||||||
max_target_length = data_args.max_target_length
|
max_target_length = data_args.max_target_length
|
||||||
padding = "max_length" if data_args.pad_to_max_length else False
|
padding = "max_length" if data_args.pad_to_max_length else False
|
||||||
|
|
||||||
|
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||||
|
logger.warn(
|
||||||
|
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
|
||||||
|
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
|
||||||
|
)
|
||||||
|
|
||||||
def preprocess_function(examples):
|
def preprocess_function(examples):
|
||||||
if data_args.task.startswith("translation"):
|
if data_args.task.startswith("translation"):
|
||||||
inputs = [ex[source_lang] for ex in examples["translation"]]
|
inputs = [ex[source_lang] for ex in examples["translation"]]
|
||||||
@@ -440,6 +446,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
data_collator = DataCollatorForSeq2Seq(
|
data_collator = DataCollatorForSeq2Seq(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
model=model,
|
||||||
label_pad_token_id=label_pad_token_id,
|
label_pad_token_id=label_pad_token_id,
|
||||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
from ..modeling_utils import PreTrainedModel
|
||||||
from ..tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase
|
from ..tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
|
||||||
@@ -232,6 +233,11 @@ class DataCollatorForSeq2Seq:
|
|||||||
Args:
|
Args:
|
||||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||||
The tokenizer used for encoding the data.
|
The tokenizer used for encoding the data.
|
||||||
|
model (:class:`~transformers.PreTrainedModel`):
|
||||||
|
The model that is being trained. If set and has the `prepare_decoder_input_ids_from_labels`, use it to
|
||||||
|
prepare the `decoder_input_ids`
|
||||||
|
|
||||||
|
This is useful when using `label_smoothing` to avoid calculating loss twice.
|
||||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
||||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||||
among:
|
among:
|
||||||
@@ -254,6 +260,7 @@ class DataCollatorForSeq2Seq:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
|
model: Optional[PreTrainedModel] = None
|
||||||
padding: Union[bool, str, PaddingStrategy] = True
|
padding: Union[bool, str, PaddingStrategy] = True
|
||||||
max_length: Optional[int] = None
|
max_length: Optional[int] = None
|
||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: Optional[int] = None
|
||||||
@@ -272,7 +279,7 @@ class DataCollatorForSeq2Seq:
|
|||||||
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
|
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.tokenizer.pad(
|
features = self.tokenizer.pad(
|
||||||
features,
|
features,
|
||||||
padding=self.padding,
|
padding=self.padding,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
@@ -280,6 +287,13 @@ class DataCollatorForSeq2Seq:
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# prepare decoder_input_ids
|
||||||
|
if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"):
|
||||||
|
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
|
||||||
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForLanguageModeling:
|
class DataCollatorForLanguageModeling:
|
||||||
|
|||||||
@@ -1341,6 +1341,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
||||||
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
|
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
|
||||||
|
|||||||
@@ -1207,6 +1207,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
|
return shift_tokens_right(labels, self.config.pad_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||||
|
|||||||
@@ -2406,6 +2406,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -1320,6 +1320,9 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
|
|||||||
@@ -1341,6 +1341,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
|
return shift_tokens_right(labels, self.config.pad_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||||
|
|||||||
@@ -1324,6 +1324,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||||
|
|||||||
@@ -1852,6 +1852,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
|||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
|
return self._shift_right(labels)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
# this function reorders the cache for beam search
|
# this function reorders the cache for beam search
|
||||||
|
|||||||
@@ -1608,6 +1608,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
|
return self._shift_right(labels)
|
||||||
|
|
||||||
def _reorder_cache(self, past, beam_idx):
|
def _reorder_cache(self, past, beam_idx):
|
||||||
# if decoder past is not included in output
|
# if decoder past is not included in output
|
||||||
# speedy decoding is disabled and no need to reorder
|
# speedy decoding is disabled and no need to reorder
|
||||||
|
|||||||
Reference in New Issue
Block a user