From 629aac92ec33c664b8470221db0a3d1921ccdc55 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Sat, 7 Mar 2020 11:45:45 +0100 Subject: [PATCH] do not allow do_sample and weird force bos token things --- src/transformers/modeling_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a1a0306684..26b6cb3968 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -962,7 +962,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # Top-p/top-k filtering next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) # Sample - next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1) + probs = F.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) else: # Greedy decoding next_token = torch.argmax(next_token_logits, dim=-1) @@ -1079,8 +1080,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) if ( - self.config.is_encoder_decoder - ): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? + self.config.is_encoder_decoder and do_sample is False + ): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here scores = self.prepare_scores_for_generation(scores, cur_len, max_length) # set eos token prob to zero if min_length is not reached @@ -1114,9 +1115,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ) # (batch_size, num_beams * vocab_size) # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) - next_tokens = torch.multinomial( - F.softmax(_scores, dim=-1), num_samples=2 * num_beams - ) # (batch_size, num_beams * 2) + probs = F.softmax(_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2) # Compute next scores next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) # sort the sampled vector to make sure that the first num_beams samples are the best