[Marian Fixes] prevent predicting pad_token_id before softmax, support language codes, name multilingual models (#4290)
This commit is contained in:
@@ -744,8 +744,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
def prepare_scores_for_generation(self, scores, **kwargs):
|
||||
return scores
|
||||
def prepare_logits_for_generation(self, logits, **kwargs):
|
||||
return logits
|
||||
|
||||
def _use_cache(self, outputs, use_cache):
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
@@ -857,7 +857,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
Defaults to `None`.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
|
||||
decoder_start_token_id=None: (`optional`) int
|
||||
If an encoder-decoder model starts decoding with a different token than BOS.
|
||||
@@ -1342,10 +1342,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if temperature != 1.0:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
if self.config.is_encoder_decoder and do_sample is False:
|
||||
# TODO (PVP) still a bit hacky here - there might be a better solutino
|
||||
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
||||
# TODO (PVP) still a bit hacky here - there might be a better solution
|
||||
next_token_logits = self.prepare_logits_for_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
|
||||
Reference in New Issue
Block a user