correct greedy generation when doing beam search (#3078)
* correct greedy generation when doing beam search * improve comment
This commit is contained in:
committed by
GitHub
parent
13afb71208
commit
2fdc7f6ce8
@@ -754,6 +754,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
else:
|
else:
|
||||||
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
||||||
|
|
||||||
|
if do_sample is False:
|
||||||
|
if num_beams == 1:
|
||||||
|
# no_beam_search greedy generation conditions
|
||||||
|
assert (
|
||||||
|
num_return_sequences == 1
|
||||||
|
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
|
||||||
|
|
||||||
|
else:
|
||||||
|
# beam_search greedy generation conditions
|
||||||
|
assert (
|
||||||
|
num_beams >= num_return_sequences
|
||||||
|
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
|
||||||
|
|
||||||
if pad_token_id is None and eos_token_ids is not None:
|
if pad_token_id is None and eos_token_ids is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
||||||
@@ -764,7 +777,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
cur_len = input_ids.shape[1]
|
cur_len = input_ids.shape[1]
|
||||||
vocab_size = self.config.vocab_size
|
vocab_size = self.config.vocab_size
|
||||||
|
|
||||||
if num_return_sequences != 1:
|
if num_return_sequences != 1 and do_sample:
|
||||||
# Expand input to num return sequences
|
# Expand input to num return sequences
|
||||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
|
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
|
||||||
input_ids = input_ids.contiguous().view(
|
input_ids = input_ids.contiguous().view(
|
||||||
@@ -787,6 +800,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
effective_batch_size,
|
effective_batch_size,
|
||||||
|
num_return_sequences,
|
||||||
length_penalty,
|
length_penalty,
|
||||||
num_beams,
|
num_beams,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
@@ -826,6 +840,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
All returned sequence are generated independantly.
|
All returned sequence are generated independantly.
|
||||||
"""
|
"""
|
||||||
# current position / max lengths / length of generated sentences / unfinished sentences
|
# current position / max lengths / length of generated sentences / unfinished sentences
|
||||||
|
|
||||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||||
|
|
||||||
@@ -906,12 +921,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
batch_size,
|
batch_size,
|
||||||
|
num_return_sequences,
|
||||||
length_penalty,
|
length_penalty,
|
||||||
num_beams,
|
num_beams,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
):
|
):
|
||||||
""" Generate sequences for each example with beam search.
|
""" Generate sequences for each example with beam search.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Expand input to num beams
|
# Expand input to num beams
|
||||||
# assert input_ids.shape == (batch_size * num_beams, cur_len)
|
# assert input_ids.shape == (batch_size * num_beams, cur_len)
|
||||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
||||||
@@ -1057,20 +1074,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
|
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
|
||||||
|
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
|
||||||
|
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
|
||||||
|
|
||||||
# select the best hypotheses
|
# select the best hypotheses
|
||||||
sent_lengths = input_ids.new(batch_size)
|
sent_lengths = input_ids.new(output_batch_size)
|
||||||
best = []
|
best = []
|
||||||
|
|
||||||
|
# retrieve best hypotheses
|
||||||
for i, hypotheses in enumerate(generated_hyps):
|
for i, hypotheses in enumerate(generated_hyps):
|
||||||
best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
|
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
|
||||||
sent_lengths[i] = len(best_hyp)
|
for j in range(output_num_return_sequences_per_batch):
|
||||||
|
effective_batch_idx = output_num_return_sequences_per_batch * i + j
|
||||||
|
best_hyp = sorted_hyps.pop()[1]
|
||||||
|
sent_lengths[effective_batch_idx] = len(best_hyp)
|
||||||
best.append(best_hyp)
|
best.append(best_hyp)
|
||||||
|
|
||||||
# shorter batches are filled with pad_token
|
# shorter batches are filled with pad_token
|
||||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
||||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||||
decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
|
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
|
||||||
|
|
||||||
# fill with hypothesis and eos_token_id if necessary
|
# fill with hypothesis and eos_token_id if necessary
|
||||||
for i, hypo in enumerate(best):
|
for i, hypo in enumerate(best):
|
||||||
|
|||||||
@@ -621,10 +621,19 @@ class ModelTesterMixin:
|
|||||||
# batch_size = 1, num_beams > 1
|
# batch_size = 1, num_beams > 1
|
||||||
self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
|
self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
# generating multiple sequences when greedy no beam generation
|
||||||
|
# is not allowed as it would always generate the same sequences
|
||||||
|
model.generate(input_ids, do_sample=False, num_return_sequences=2)
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
# generating more sequences than having beams leads is not possible
|
||||||
|
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||||
|
|
||||||
# batch_size > 1, sample
|
# batch_size > 1, sample
|
||||||
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
||||||
# batch_size > 1, greedy
|
# batch_size > 1, greedy
|
||||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
|
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
|
||||||
# batch_size > 1, num_beams > 1, sample
|
# batch_size > 1, num_beams > 1, sample
|
||||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
|
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
|
||||||
# batch_size > 1, num_beams > 1, greedy
|
# batch_size > 1, num_beams > 1, greedy
|
||||||
|
|||||||
Reference in New Issue
Block a user