fix beam_search behavior when sampling (#3106)
* fix beam_search behavior when sampling
* delete print
* make correct style
This commit is contained in:
committed by
GitHub
parent e9e6efdc45
commit 6701fb7859
@@ -564,7 +564,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if output_loading_info:
|
if output_loading_info:
|
||||||
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
|
loading_info = {
|
||||||
|
"missing_keys": missing_keys,
|
||||||
|
"unexpected_keys": unexpected_keys,
|
||||||
|
"error_msgs": error_msgs,
|
||||||
|
}
|
||||||
return model, loading_info
|
return model, loading_info
|
||||||
|
|
||||||
return model
|
return model
|
||||||
@@ -941,7 +945,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# scores for each sentence in the beam
|
# scores for each sentence in the beam
|
||||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||||
beam_scores[:, 1:] = -1e9
|
|
||||||
|
# For greedy decoding, make sure that only words of the first beam are considered at the first step, to avoid selecting the exact same words once per beam
|
||||||
|
if do_sample is False:
|
||||||
|
beam_scores[:, 1:] = -1e9
|
||||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||||
|
|
||||||
# cache compute states
|
# cache compute states
|
||||||
@@ -967,19 +974,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
scores = scores / temperature
|
scores = scores / temperature
|
||||||
|
|
||||||
|
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
# Top-p/top-k filtering
|
# Top-p/top-k filtering
|
||||||
scores = top_k_top_p_filtering(
|
_scores = top_k_top_p_filtering(
|
||||||
scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
||||||
) # (batch_size * num_beams, vocab_size)
|
) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
|
# re-organize to group the beam together to sample from all beam_idxs
|
||||||
|
_scores = _scores.contiguous().view(
|
||||||
|
batch_size, num_beams * vocab_size
|
||||||
|
) # (batch_size, num_beams * vocab_size)
|
||||||
|
|
||||||
# Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
|
# Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
|
||||||
next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2) # (batch_size * num_beams, 2)
|
next_words = torch.multinomial(
|
||||||
|
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
|
||||||
|
) # (batch_size, num_beams * 2)
|
||||||
|
|
||||||
# Compute next scores
|
# Compute next scores
|
||||||
_scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
|
next_scores = torch.gather(_scores, -1, next_words) # (batch_size, num_beams * 2)
|
||||||
_scores = torch.gather(_scores, -1, next_words) # (batch_size * num_beams, 2)
|
|
||||||
next_scores = _scores + beam_scores[:, None].expand_as(_scores) # (batch_size * num_beams, 2)
|
|
||||||
# Match shape of greedy beam search
|
|
||||||
next_words = next_words.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
|
|
||||||
next_scores = next_scores.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
|
|
||||||
else:
|
else:
|
||||||
# do greedy beam search
|
# do greedy beam search
|
||||||
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
|
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
@@ -1026,7 +1042,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# add to generated hypotheses if end of sentence or last iteration
|
# add to generated hypotheses if end of sentence or last iteration
|
||||||
if eos_token_ids is not None and word_id.item() in eos_token_ids:
|
if eos_token_ids is not None and word_id.item() in eos_token_ids:
|
||||||
generated_hyps[batch_idx].add(
|
generated_hyps[batch_idx].add(
|
||||||
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
|
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# add next predicted word if it is not eos_token
|
# add next predicted word if it is not eos_token
|
||||||
|
|||||||
Reference in New Issue
Block a user