correct greedy generation when doing beam search (#3078)
* correct greedy generation when doing beam search * improve comment
This commit is contained in:
committed by
GitHub
parent
13afb71208
commit
2fdc7f6ce8
@@ -754,6 +754,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
else:
|
||||
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
||||
|
||||
if do_sample is False:
|
||||
if num_beams == 1:
|
||||
# no_beam_search greedy generation conditions
|
||||
assert (
|
||||
num_return_sequences == 1
|
||||
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
|
||||
|
||||
else:
|
||||
# beam_search greedy generation conditions
|
||||
assert (
|
||||
num_beams >= num_return_sequences
|
||||
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
|
||||
|
||||
if pad_token_id is None and eos_token_ids is not None:
|
||||
logger.warning(
|
||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
||||
@@ -764,7 +777,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
cur_len = input_ids.shape[1]
|
||||
vocab_size = self.config.vocab_size
|
||||
|
||||
if num_return_sequences != 1:
|
||||
if num_return_sequences != 1 and do_sample:
|
||||
# Expand input to num return sequences
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
|
||||
input_ids = input_ids.contiguous().view(
|
||||
@@ -787,6 +800,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
effective_batch_size,
|
||||
num_return_sequences,
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
@@ -826,6 +840,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
All returned sequences are generated independently.
|
||||
"""
|
||||
# current position / max lengths / length of generated sentences / unfinished sentences
|
||||
|
||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||
|
||||
@@ -906,12 +921,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
batch_size,
|
||||
num_return_sequences,
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
|
||||
# Expand input to num beams
|
||||
# assert input_ids.shape == (batch_size * num_beams, cur_len)
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
||||
@@ -1057,20 +1074,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
|
||||
)
|
||||
|
||||
# when not sampling (greedy beam search), return the top `num_return_sequences` hypotheses per batch entry, so the output batch grows to batch_size * num_return_sequences; when sampling, return a single hypothesis per batch entry
|
||||
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
|
||||
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
|
||||
|
||||
# select the best hypotheses
|
||||
sent_lengths = input_ids.new(batch_size)
|
||||
sent_lengths = input_ids.new(output_batch_size)
|
||||
best = []
|
||||
|
||||
# retrieve best hypotheses
|
||||
for i, hypotheses in enumerate(generated_hyps):
|
||||
best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
|
||||
sent_lengths[i] = len(best_hyp)
|
||||
best.append(best_hyp)
|
||||
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
|
||||
for j in range(output_num_return_sequences_per_batch):
|
||||
effective_batch_idx = output_num_return_sequences_per_batch * i + j
|
||||
best_hyp = sorted_hyps.pop()[1]
|
||||
sent_lengths[effective_batch_idx] = len(best_hyp)
|
||||
best.append(best_hyp)
|
||||
|
||||
# shorter batches are filled with pad_token
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||
decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
|
||||
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
|
||||
|
||||
# fill with hypothesis and eos_token_id if necessary
|
||||
for i, hypo in enumerate(best):
|
||||
|
||||
Reference in New Issue
Block a user