[Marian Fixes] prevent predicting pad_token_id before softmax, support language codes, name multilingual models (#4290)
This commit is contained in:
@@ -980,12 +980,12 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
def prepare_logits_for_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == 1:
|
||||
self._force_token_ids_generation(scores, self.config.bos_token_id)
|
||||
self._force_token_ids_generation(logits, self.config.bos_token_id)
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
||||
return scores
|
||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
def _force_token_ids_generation(self, scores, token_ids) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
|
||||
|
||||
Reference in New Issue
Block a user