[examples/seq2seq] support label smoothing (#9844)

* add prepare_decoder_input_ids_from_labels in s2s models

* support lbl smoothing and enc/emb freezing

* fix freezing

* use pad_token_id from config

* remove embed freezing and add warning

* prepare decoder_input_ids inside DataCollatorForSeq2Seq
This commit is contained in:
Suraj Patil
2021-02-05 23:21:57 +05:30
committed by GitHub
parent b9720dd6f2
commit 1cd16512dc
10 changed files with 46 additions and 1 deletions

View File

@@ -1341,6 +1341,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)

View File

@@ -1207,6 +1207,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(logits, self.config.eos_token_id)

View File

@@ -2406,6 +2406,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()

View File

@@ -1320,6 +1320,9 @@ class MarianMTModel(MarianPreTrainedModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
if cur_len == max_length - 1 and self.config.eos_token_id is not None:

View File

@@ -1341,6 +1341,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)

View File

@@ -1324,6 +1324,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)

View File

@@ -1852,6 +1852,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
"use_cache": use_cache,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)
@staticmethod
def _reorder_cache(past, beam_idx):
# this function reorders the cache for beam search

View File

@@ -1608,6 +1608,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
"use_cache": use_cache,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)
def _reorder_cache(self, past, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder