From 2fdc7f6ce8e15793568645b46e4badf7dbe4ecd8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 2 Mar 2020 18:00:09 +0100 Subject: [PATCH] correct greedy generation when doing beam search (#3078) * correct greedy generation when doing beam search * improve comment --- src/transformers/modeling_utils.py | 37 +++++++++++++++++++++++++----- tests/test_modeling_common.py | 11 ++++++++- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index efd622656e..e771fd5cc9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -754,6 +754,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): else: assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." + if do_sample is False: + if num_beams == 1: + # no_beam_search greedy generation conditions + assert ( + num_return_sequences == 1 + ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" + + else: + # beam_search greedy generation conditions + assert ( + num_beams >= num_return_sequences + ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" + if pad_token_id is None and eos_token_ids is not None: logger.warning( "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0]) @@ -764,7 +777,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): cur_len = input_ids.shape[1] vocab_size = self.config.vocab_size - if num_return_sequences != 1: + if num_return_sequences != 1 and do_sample: # Expand input to num return sequences input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len) input_ids = input_ids.contiguous().view( @@ -787,6 +800,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): pad_token_id, eos_token_ids, effective_batch_size, + num_return_sequences, length_penalty, num_beams, vocab_size, @@ -826,6 +840,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): All returned sequence are generated independantly. """ # current position / max lengths / length of generated sentences / unfinished sentences + unfinished_sents = input_ids.new(batch_size).fill_(1) sent_lengths = input_ids.new(batch_size).fill_(max_length) @@ -906,12 +921,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): pad_token_id, eos_token_ids, batch_size, + num_return_sequences, length_penalty, num_beams, vocab_size, ): """ Generate sequences for each example with beam search. """ + # Expand input to num beams # assert input_ids.shape == (batch_size * num_beams, cur_len) input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len) @@ -1057,20 +1074,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item() ) + # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch + output_batch_size = batch_size if do_sample else batch_size * num_return_sequences + output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences + # select the best hypotheses - sent_lengths = input_ids.new(batch_size) + sent_lengths = input_ids.new(output_batch_size) best = [] + # retrieve best hypotheses for i, hypotheses in enumerate(generated_hyps): - best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1] - sent_lengths[i] = len(best_hyp) - best.append(best_hyp) + sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) + for j in range(output_num_return_sequences_per_batch): + effective_batch_idx = output_num_return_sequences_per_batch * i + j + best_hyp = sorted_hyps.pop()[1] + sent_lengths[effective_batch_idx] = len(best_hyp) + best.append(best_hyp) # shorter batches are filled with pad_token if sent_lengths.min().item() != sent_lengths.max().item(): assert pad_token_id is not None, "`Pad_token_id` has to be defined" sent_max_len = min(sent_lengths.max().item() + 1, max_length) - decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id) + decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id) # fill with hypothesis and eos_token_id if necessary for i, hypo in enumerate(best): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4e5202a65a..5277864eca 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -621,10 +621,19 @@ class ModelTesterMixin: # batch_size = 1, num_beams > 1 self._check_generated_tokens(model.generate(max_length=5, num_beams=3)) + with self.assertRaises(AssertionError): + # generating multiple sequences when greedy no beam generation + # is not allowed as it would always generate the same sequences + model.generate(input_ids, do_sample=False, num_return_sequences=2) + + with self.assertRaises(AssertionError): + # generating more sequences than having beams leads is not possible + model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2) + # batch_size > 1, sample self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3)) # batch_size > 1, greedy - self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3)) + self._check_generated_tokens(model.generate(input_ids, do_sample=False)) # batch_size > 1, num_beams > 1, sample self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,)) # batch_size > 1, num_beams > 1, greedy