[RAG] fix generate (#10094)
* fix rag generate and tests
* put back adjust_logits_during_generation
* tests are okay

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
length_penalty=None,
|
length_penalty=None,
|
||||||
no_repeat_ngram_size=None,
|
no_repeat_ngram_size=None,
|
||||||
|
encoder_no_repeat_ngram_size=None,
|
||||||
repetition_penalty=None,
|
repetition_penalty=None,
|
||||||
bad_words_ids=None,
|
bad_words_ids=None,
|
||||||
num_return_sequences=None,
|
num_return_sequences=None,
|
||||||
@@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
order to encourage the model to produce longer sequences.
|
order to encourage the model to produce longer sequences.
|
||||||
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||||
If set to int > 0, all ngrams of that size can only occur once.
|
If set to int > 0, all ngrams of that size can only occur once.
|
||||||
|
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
|
||||||
|
``decoder_input_ids``.
|
||||||
bad_words_ids(:obj:`List[int]`, `optional`):
|
bad_words_ids(:obj:`List[int]`, `optional`):
|
||||||
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
|
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
|
||||||
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
|
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
|
||||||
@@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
pre_processor = self._get_logits_processor(
|
pre_processor = self._get_logits_processor(
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
|
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||||
|
encoder_input_ids=context_input_ids,
|
||||||
bad_words_ids=bad_words_ids,
|
bad_words_ids=bad_words_ids,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user