From c37dcff764fa20c048bcff6e9c8cbce5b84151a0 Mon Sep 17 00:00:00 2001 From: Michael Glass <35044941+michaelrglass@users.noreply.github.com> Date: Tue, 26 Jan 2021 10:44:02 -0500 Subject: [PATCH] Fixed parameter name for logits_processor (#9790) --- src/transformers/models/rag/modeling_rag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 2b733cd1b6..e39e84ad5e 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1486,7 +1486,7 @@ class RagTokenForGeneration(RagPreTrainedModel): ) return self.greedy_search( input_ids, - pre_processor=pre_processor, + logits_processor=pre_processor, max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, @@ -1509,7 +1509,7 @@ class RagTokenForGeneration(RagPreTrainedModel): return self.beam_search( input_ids, beam_scorer, - pre_processor=pre_processor, + logits_processor=pre_processor, max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id,