[fix] Move _adjust_logits above postprocess to fix Marian.generate (#5126)

This commit is contained in:
Sam Shleifer
2020-06-18 18:06:27 -04:00
committed by GitHub
parent 3d3e605aff
commit 8a377c3d6e
3 changed files with 8 additions and 7 deletions

View File

@@ -792,7 +792,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids}
def prepare_logits_for_generation(self, logits, **kwargs):
def adjust_logits_during_generation(self, logits, **kwargs):
return logits
def _use_cache(self, outputs, use_cache):
@@ -1396,6 +1396,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solution
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
@@ -1413,10 +1418,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
num_beams=num_beams,
)
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solution
scores = self.prepare_logits_for_generation(scores, cur_len=cur_len, max_length=max_length)
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
scores.shape, (batch_size * num_beams, vocab_size)
)