[RAG] fix generate (#10094)
* fix rag generate and tests * put back adjust_logits_during_generation * tests are okay Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
eos_token_id=None,
|
||||
length_penalty=None,
|
||||
no_repeat_ngram_size=None,
|
||||
encoder_no_repeat_ngram_size=None,
|
||||
repetition_penalty=None,
|
||||
bad_words_ids=None,
|
||||
num_return_sequences=None,
|
||||
@@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
order to encourage the model to produce longer sequences.
|
||||
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||
If set to int > 0, all ngrams of that size can only occur once.
|
||||
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
|
||||
``decoder_input_ids``.
|
||||
bad_words_ids(:obj:`List[int]`, `optional`):
|
||||
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
|
||||
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
|
||||
@@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
pre_processor = self._get_logits_processor(
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||
encoder_input_ids=context_input_ids,
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
eos_token_id=eos_token_id,
|
||||
|
||||
Reference in New Issue
Block a user