Diverse beam search 2 (#9006)

* diverse beam search

* bug fixes

* bug fixes

* bug fix

* separate out diverse_beam_search function

* separate out diverse_beam_search function

* bug fix

* improve code quality

* bug fix

* bug fix

* separate out diverse beam search scorer

* code format

* code format

* code format

* code format

* add test

* code format

* documentation changes

* code quality

* add slow integration tests

* more general name

* refactor into logits processor

* add test

* avoid too much copy paste

* refactor

* add to docs

* fix-copies

* bug fix

* Revert "bug fix"

This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4.

* improve comment

* implement sylvains feedback

Co-authored-by: Ayush Jain <a.jain@sprinklr.com>
Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-12-09 15:00:37 +01:00
committed by GitHub
parent 67ff1c314a
commit 02d0e0355c
10 changed files with 590 additions and 20 deletions

View File

@@ -95,6 +95,12 @@ class PretrainedConfig(object):
sentences are finished per batch or not.
- **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be used by
default in the :obj:`generate` method of the model. 1 means no beam search.
- **num_beam_groups** (:obj:`int`, `optional`, defaults to 1) -- Number of groups to divide :obj:`num_beams`
into in order to ensure diversity among different groups of beams that will be used by default in the
:obj:`generate` method of the model. 1 means no group beam search.
- **diversity_penalty** (:obj:`float`, `optional`, defaults to 0.0) -- Value to control diversity for group
beam search. that will be used by default in the :obj:`generate` method of the model. 0 means no diversity
penalty. The higher the penalty, the more diverse are the outputs.
- **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to module the next token
probabilities that will be used by default in the :obj:`generate` method of the model. Must be strictly
positive.
@@ -185,6 +191,8 @@ class PretrainedConfig(object):
self.do_sample = kwargs.pop("do_sample", False)
self.early_stopping = kwargs.pop("early_stopping", False)
self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
self.temperature = kwargs.pop("temperature", 1.0)
self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop("top_p", 1.0)