Diverse beam search 2 (#9006)

* diverse beam search

* bug fixes

* bug fixes

* bug fix

* separate out diverse_beam_search function

* separate out diverse_beam_search function

* bug fix

* improve code quality

* bug fix

* bug fix

* separate out diverse beam search scorer

* code format

* code format

* code format

* code format

* add test

* code format

* documentation changes

* code quality

* add slow integration tests

* more general name

* refactor into logits processor

* add test

* avoid too much copy paste

* refactor

* add to docs

* fix-copies

* bug fix

* Revert "bug fix"

This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4.

* improve comment

* implement sylvains feedback

Co-authored-by: Ayush Jain <a.jain@sprinklr.com>
Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-12-09 15:00:37 +01:00
committed by GitHub
parent 67ff1c314a
commit 02d0e0355c
10 changed files with 590 additions and 20 deletions

View File

@@ -1226,6 +1226,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
early_stopping=None,
use_cache=None,
num_beams=None,
num_beam_groups=None,
diversity_penalty=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
@@ -1302,6 +1304,13 @@ class RagTokenForGeneration(RagPreTrainedModel):
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
num_beams (:obj:`int`, `optional`, defaults to 1):
Number of beams for beam search. 1 means no beam search.
num_beam_groups (:obj:`int`, `optional`, defaults to 1):
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
beams. `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
diversity_penalty (:obj:`float`, `optional`, defaults to 0.0):
This value is subtracted from a beam's score if it generates a token same as any beam from other group
at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is
enabled.
num_return_sequences(:obj:`int`, `optional`, defaults to 1):
The number of independently computed returned sequences for each element in the batch. Note that this
is not the value we pass to the ``generator``'s `:func:`~transformers.PreTrainedModel.generate`
@@ -1326,6 +1335,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
# set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs
num_beams = num_beams if num_beams is not None else self.config.num_beams
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
max_length = max_length if max_length is not None else self.config.max_length
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
@@ -1412,6 +1422,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
eos_token_id=eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
diversity_penalty=diversity_penalty,
)
if num_beams == 1: