[RAG] fix generate (#10094)

* fix rag generate and tests

* put back adjust_logits_during_generation

* tests are okay

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2021-02-10 00:27:38 +05:30
committed by GitHub
parent 226973a9c5
commit 3e0c62b611

View File

@@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
encoder_no_repeat_ngram_size=None,
repetition_penalty=None,
bad_words_ids=None,
num_return_sequences=None,
@@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
order to encourage the model to produce longer sequences.
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
``decoder_input_ids``.
bad_words_ids(:obj:`List[int]`, `optional`):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
@@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
pre_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
encoder_input_ids=context_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,