[fix] Move _adjust_logits above postprocess to fix Marian.generate (#5126)
This commit is contained in:
@@ -993,7 +993,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
def prepare_logits_for_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
if cur_len == 1:
|
if cur_len == 1:
|
||||||
self._force_token_ids_generation(logits, self.config.bos_token_id)
|
self._force_token_ids_generation(logits, self.config.bos_token_id)
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class MarianMTModel(BartForConditionalGeneration):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare_logits_for_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
logits[:, self.config.pad_token_id] = float("-inf")
|
logits[:, self.config.pad_token_id] = float("-inf")
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||||
|
|||||||
@@ -792,7 +792,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
return {"input_ids": input_ids}
|
return {"input_ids": input_ids}
|
||||||
|
|
||||||
def prepare_logits_for_generation(self, logits, **kwargs):
|
def adjust_logits_during_generation(self, logits, **kwargs):
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def _use_cache(self, outputs, use_cache):
|
def _use_cache(self, outputs, use_cache):
|
||||||
@@ -1396,6 +1396,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# if model has past, then set the past variable to speed up decoding
|
# if model has past, then set the past variable to speed up decoding
|
||||||
if self._use_cache(outputs, use_cache):
|
if self._use_cache(outputs, use_cache):
|
||||||
past = outputs[1]
|
past = outputs[1]
|
||||||
|
if self.config.is_encoder_decoder and do_sample is False:
|
||||||
|
# TODO (PVP) still a bit hacky here - there might be a better solution
|
||||||
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
|
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||||
|
)
|
||||||
|
|
||||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
@@ -1413,10 +1418,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.is_encoder_decoder and do_sample is False:
|
|
||||||
# TODO (PVP) still a bit hacky here - there might be a better solution
|
|
||||||
scores = self.prepare_logits_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
|
||||||
|
|
||||||
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
|
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
|
||||||
scores.shape, (batch_size * num_beams, vocab_size)
|
scores.shape, (batch_size * num_beams, vocab_size)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user