[RAG] fix generate (#10094)

* fix rag generate and tests

* put back adjust_logits_during_generation

* tests are okay

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2021-02-10 00:27:38 +05:30
committed by GitHub
parent 226973a9c5
commit 3e0c62b611

View File

@@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
eos_token_id=None, eos_token_id=None,
length_penalty=None, length_penalty=None,
no_repeat_ngram_size=None, no_repeat_ngram_size=None,
encoder_no_repeat_ngram_size=None,
repetition_penalty=None, repetition_penalty=None,
bad_words_ids=None, bad_words_ids=None,
num_return_sequences=None, num_return_sequences=None,
@@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
order to encourage the model to produce longer sequences. order to encourage the model to produce longer sequences.
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once. If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
``decoder_input_ids``.
bad_words_ids(:obj:`List[int]`, `optional`): bad_words_ids(:obj:`List[int]`, `optional`):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
@@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
pre_processor = self._get_logits_processor( pre_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size, no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
encoder_input_ids=context_input_ids,
bad_words_ids=bad_words_ids, bad_words_ids=bad_words_ids,
min_length=min_length, min_length=min_length,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,