Do not apply the forced-BOS-token score preparation when do_sample is enabled
This commit is contained in:
committed by
Patrick von Platen
parent
d880a5fbde
commit
629aac92ec
@@ -962,7 +962,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# Top-p/top-k filtering
|
# Top-p/top-k filtering
|
||||||
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||||
# Sample
|
# Sample
|
||||||
next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
|
probs = F.softmax(next_token_logits, dim=-1)
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||||
else:
|
else:
|
||||||
# Greedy decoding
|
# Greedy decoding
|
||||||
next_token = torch.argmax(next_token_logits, dim=-1)
|
next_token = torch.argmax(next_token_logits, dim=-1)
|
||||||
@@ -1079,8 +1080,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
if (
|
if (
|
||||||
self.config.is_encoder_decoder
|
self.config.is_encoder_decoder and do_sample is False
|
||||||
): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search?
|
): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
|
||||||
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
|
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
@@ -1114,9 +1115,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
) # (batch_size, num_beams * vocab_size)
|
) # (batch_size, num_beams * vocab_size)
|
||||||
|
|
||||||
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
|
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
|
||||||
next_tokens = torch.multinomial(
|
probs = F.softmax(_scores, dim=-1)
|
||||||
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
|
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
|
||||||
) # (batch_size, num_beams * 2)
|
|
||||||
# Compute next scores
|
# Compute next scores
|
||||||
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
|
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
|
||||||
# sort the sampled vector to make sure that the first num_beams samples are the best
|
# sort the sampled vector to make sure that the first num_beams samples are the best
|
||||||
|
|||||||
Reference in New Issue
Block a user