From 8a377c3d6e20f8e38b9848892baf9baddd168dc5 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 18 Jun 2020 18:06:27 -0400 Subject: [PATCH] [fix] Move _adjust_logits above postprocess to fix Marian.generate (#5126) --- src/transformers/modeling_bart.py | 2 +- src/transformers/modeling_marian.py | 2 +- src/transformers/modeling_utils.py | 11 ++++++----- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 160981bb53..d8babc3a24 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -993,7 +993,7 @@ class BartForConditionalGeneration(PretrainedBartModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } - def prepare_logits_for_generation(self, logits, cur_len, max_length): + def adjust_logits_during_generation(self, logits, cur_len, max_length): if cur_len == 1: self._force_token_ids_generation(logits, self.config.bos_token_id) if cur_len == max_length - 1 and self.config.eos_token_id is not None: diff --git a/src/transformers/modeling_marian.py b/src/transformers/modeling_marian.py index 65ef91f64d..e8598ca1f6 100644 --- a/src/transformers/modeling_marian.py +++ b/src/transformers/modeling_marian.py @@ -46,7 +46,7 @@ class MarianMTModel(BartForConditionalGeneration): """ - def prepare_logits_for_generation(self, logits, cur_len, max_length): + def adjust_logits_during_generation(self, logits, cur_len, max_length): logits[:, self.config.pad_token_id] = float("-inf") if cur_len == max_length - 1 and self.config.eos_token_id is not None: self._force_token_ids_generation(logits, self.config.eos_token_id) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index da072491e9..bb9c53d7d2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -792,7 +792,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids} - def prepare_logits_for_generation(self, logits, **kwargs): + def adjust_logits_during_generation(self, logits, **kwargs): return logits def _use_cache(self, outputs, use_cache): @@ -1396,6 +1396,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # if model has past, then set the past variable to speed up decoding if self._use_cache(outputs, use_cache): past = outputs[1] + if self.config.is_encoder_decoder and do_sample is False: + # TODO (PVP) still a bit hacky here - there might be a better solution + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length + ) scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) @@ -1413,10 +1418,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): num_beams=num_beams, ) - if self.config.is_encoder_decoder and do_sample is False: - # TODO (PVP) still a bit hacky here - there might be a better solution - scores = self.prepare_logits_for_generation(scores, cur_len=cur_len, max_length=max_length) - assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format( scores.shape, (batch_size * num_beams, vocab_size) )