[Marian Fixes] prevent predicting pad_token_id before softmax, support language codes, name multilingual models (#4290)

This commit is contained in:
Sam Shleifer
2020-05-13 17:29:41 -04:00
committed by GitHub
parent 839bfaedb2
commit 9a687ebb77
7 changed files with 274 additions and 82 deletions

View File

@@ -980,12 +980,12 @@ class BartForConditionalGeneration(PretrainedBartModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):
def prepare_logits_for_generation(self, logits, cur_len, max_length):
if cur_len == 1:
self._force_token_ids_generation(scores, self.config.bos_token_id)
self._force_token_ids_generation(logits, self.config.bos_token_id)
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(scores, self.config.eos_token_id)
return scores
self._force_token_ids_generation(logits, self.config.eos_token_id)
return logits
def _force_token_ids_generation(self, scores, token_ids) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""