Diverse beam search 2 (#9006)
* diverse beam search * bug fixes * bug fixes * bug fix * separate out diverse_beam_search function * separate out diverse_beam_search function * bug fix * improve code quality * bug fix * bug fix * separate out diverse beam search scorer * code format * code format * code format * code format * add test * code format * documentation changes * code quality * add slow integration tests * more general name * refactor into logits processor * add test * avoid too much copy paste * refactor * add to docs * fix-copies * bug fix * Revert "bug fix" This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4. * improve comment * implement sylvains feedback Co-authored-by: Ayush Jain <a.jain@sprinklr.com> Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
67ff1c314a
commit
02d0e0355c
@@ -95,6 +95,12 @@ class PretrainedConfig(object):
|
||||
sentences are finished per batch or not.
|
||||
- **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be used by
|
||||
default in the :obj:`generate` method of the model. 1 means no beam search.
|
||||
- **num_beam_groups** (:obj:`int`, `optional`, defaults to 1) -- Number of groups to divide :obj:`num_beams`
|
||||
into in order to ensure diversity among different groups of beams that will be used by default in the
|
||||
:obj:`generate` method of the model. 1 means no group beam search.
|
||||
- **diversity_penalty** (:obj:`float`, `optional`, defaults to 0.0) -- Value to control diversity for group
|
||||
beam search. that will be used by default in the :obj:`generate` method of the model. 0 means no diversity
|
||||
penalty. The higher the penalty, the more diverse are the outputs.
|
||||
- **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to module the next token
|
||||
probabilities that will be used by default in the :obj:`generate` method of the model. Must be strictly
|
||||
positive.
|
||||
@@ -185,6 +191,8 @@ class PretrainedConfig(object):
|
||||
self.do_sample = kwargs.pop("do_sample", False)
|
||||
self.early_stopping = kwargs.pop("early_stopping", False)
|
||||
self.num_beams = kwargs.pop("num_beams", 1)
|
||||
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
||||
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
||||
self.temperature = kwargs.pop("temperature", 1.0)
|
||||
self.top_k = kwargs.pop("top_k", 50)
|
||||
self.top_p = kwargs.pop("top_p", 1.0)
|
||||
|
||||
Reference in New Issue
Block a user