Diverse beam search 2 (#9006)
* diverse beam search
* bug fixes
* bug fixes
* bug fix
* separate out diverse_beam_search function
* separate out diverse_beam_search function
* bug fix
* improve code quality
* bug fix
* bug fix
* separate out diverse beam search scorer
* code format
* code format
* code format
* code format
* add test
* code format
* documentation changes
* code quality
* add slow integration tests
* more general name
* refactor into logits processor
* add test
* avoid too much copy paste
* refactor
* add to docs
* fix-copies
* bug fix
* Revert "bug fix"
  This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4.
* improve comment
* implement sylvains feedback

Co-authored-by: Ayush Jain <a.jain@sprinklr.com>
Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
Committed by: GitHub
Parent: 67ff1c314a
Commit: 02d0e0355c
@@ -1226,6 +1226,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
         early_stopping=None,
         use_cache=None,
         num_beams=None,
+        num_beam_groups=None,
+        diversity_penalty=None,
         bos_token_id=None,
         pad_token_id=None,
         eos_token_id=None,
@@ -1302,6 +1304,13 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
             num_beams (:obj:`int`, `optional`, defaults to 1):
                 Number of beams for beam search. 1 means no beam search.
+            num_beam_groups (:obj:`int`, `optional`, defaults to 1):
+                Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
+                beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
+            diversity_penalty (:obj:`float`, `optional`, defaults to 0.0):
+                This value is subtracted from a beam's score if it generates the same token as any beam from another
+                group at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam
+                search`` is enabled.
             num_return_sequences(:obj:`int`, `optional`, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch. Note that this
                 is not the value we pass to the ``generator``'s `:func:`~transformers.PreTrainedModel.generate`
@@ -1326,6 +1335,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         # set default parameters
         n_docs = n_docs if n_docs is not None else self.config.n_docs
         num_beams = num_beams if num_beams is not None else self.config.num_beams
+        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
         max_length = max_length if max_length is not None else self.config.max_length
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
@@ -1412,6 +1422,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
             eos_token_id=eos_token_id,
             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             num_beams=num_beams,
+            num_beam_groups=num_beam_groups,
+            diversity_penalty=diversity_penalty,
         )
 
         if num_beams == 1:
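
Note on the new parameters: the docstring above describes num_beam_groups and diversity_penalty only in words, so a small standalone sketch may help. It is not the code from this commit. The helper names split_into_groups and apply_diversity_penalty are hypothetical, the divisibility check is an assumption, and scaling the penalty by how many earlier beams picked a token is one reading of the docstring (a Hamming-style penalty), not something the diff itself confirms.

```python
from collections import Counter
from typing import Dict, List, Sequence


def split_into_groups(num_beams: int, num_beam_groups: int) -> List[List[int]]:
    """Partition beam indices 0..num_beams-1 into equally sized groups.

    Assumption: num_beams must be divisible by num_beam_groups.
    """
    if num_beam_groups < 1 or num_beams % num_beam_groups != 0:
        raise ValueError("num_beams must be a positive multiple of num_beam_groups")
    size = num_beams // num_beam_groups
    return [list(range(g * size, (g + 1) * size)) for g in range(num_beam_groups)]


def apply_diversity_penalty(
    scores: Dict[int, float],
    previous_group_tokens: Sequence[int],
    diversity_penalty: float,
) -> Dict[int, float]:
    """Return a copy of `scores` (token id -> score) for the current group.

    Each token already chosen at this step by a beam in an earlier group
    loses `diversity_penalty` once per earlier beam that chose it
    (assumed Hamming-style scaling).
    """
    counts = Counter(previous_group_tokens)
    return {
        token: score - diversity_penalty * counts[token]
        for token, score in scores.items()
    }


# Two groups of two beams each.
groups = split_into_groups(num_beams=4, num_beam_groups=2)

# Earlier groups picked token 1 twice and token 3 once at this time step.
step_scores = {1: -0.5, 2: -1.0, 3: -2.0}
penalized = apply_diversity_penalty(step_scores, [1, 1, 3], diversity_penalty=0.5)
```

With diversity_penalty left at its default of 0.0 the scores come back unchanged, which matches the docstring's remark that the penalty only has an effect when group beam search is enabled. After this commit, the corresponding keyword arguments on generate would presumably be passed together, for example num_beams=4, num_beam_groups=2, diversity_penalty=0.5.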