From 3e0c62b611722e1be0d6cbeba0db7974e7cdc1f0 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 10 Feb 2021 00:27:38 +0530 Subject: [PATCH] [RAG] fix generate (#10094) * fix rag generate and tests * put back adjust_logits_during_generation * tests are okay Co-authored-by: Patrick von Platen --- src/transformers/models/rag/modeling_rag.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 3501720060..b421751d27 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel): eos_token_id=None, length_penalty=None, no_repeat_ngram_size=None, + encoder_no_repeat_ngram_size=None, repetition_penalty=None, bad_words_ids=None, num_return_sequences=None, @@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel): order to encourage the model to produce longer sequences. no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): If set to int > 0, all ngrams of that size can only occur once. + encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): + If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the + ``decoder_input_ids``. bad_words_ids(:obj:`List[int]`, `optional`): List of token ids that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. @@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel): pre_processor = self._get_logits_processor( repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, + encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + encoder_input_ids=context_input_ids, bad_words_ids=bad_words_ids, min_length=min_length, eos_token_id=eos_token_id,