From 6701fb7859797132a9c82f56ce34bde8ed0a768f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 4 Mar 2020 15:30:51 +0100 Subject: [PATCH] fix beam_search behavior when sampling (#3106) * fix beam_search behavior when sampling * delete print * make correct style --- src/transformers/modeling_utils.py | 40 +++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e771fd5cc9..3dc0f245c9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -564,7 +564,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): model.eval() if output_loading_info: - loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "error_msgs": error_msgs, + } return model, loading_info return model @@ -941,7 +945,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # scores for each sentence in the beam beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores[:, 1:] = -1e9 + + # For greedy decoding, make sure that only words of the first beam are considered, to avoid selecting the exact same words multiple times + if do_sample is False: + beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) # cache compute states @@ -967,19 +974,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # Temperature (higher temperature => more likely to sample low probability tokens) if temperature != 1.0: scores = scores / temperature + + scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) + _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + # Top-p/top-k filtering - scores = top_k_top_p_filtering( - scores, top_k=top_k, top_p=top_p, 
min_tokens_to_keep=2 + _scores = top_k_top_p_filtering( + _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 ) # (batch_size * num_beams, vocab_size) + + # re-organize to group the beam together to sample from all beam_idxs + _scores = _scores.contiguous().view( + batch_size, num_beams * vocab_size + ) # (batch_size, num_beams * vocab_size) + # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search) - next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2) # (batch_size * num_beams, 2) + next_words = torch.multinomial( + F.softmax(_scores, dim=-1), num_samples=2 * num_beams + ) # (batch_size, num_beams * 2) + # Compute next scores - _scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) - _scores = torch.gather(_scores, -1, next_words) # (batch_size * num_beams, 2) - next_scores = _scores + beam_scores[:, None].expand_as(_scores) # (batch_size * num_beams, 2) - # Match shape of greedy beam search - next_words = next_words.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams) - next_scores = next_scores.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams) + next_scores = torch.gather(_scores, -1, next_words) # (batch_size, num_beams * 2) + else: # do greedy beam search scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) @@ -1026,7 +1042,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # add to generated hypotheses if end of sentence or last iteration if eos_token_ids is not None and word_id.item() in eos_token_ids: generated_hyps[batch_idx].add( - input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item() + input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(), ) else: # add next predicted word if it is not eos_token