From 5f7a07c0c867abedbb3ebf135915eeee56add24b Mon Sep 17 00:00:00 2001 From: Derrick Blakely Date: Sat, 2 Jan 2021 04:39:14 -0700 Subject: [PATCH] use return dict for rag encoder (#9363) --- src/transformers/models/rag/modeling_rag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 304a05b85a..2b733cd1b6 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1437,7 +1437,7 @@ class RagTokenForGeneration(RagPreTrainedModel): batch_size = context_input_ids.shape[0] // n_docs encoder = self.rag.generator.get_encoder() - encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask) + encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True) input_ids = torch.full( (batch_size * num_beams, 1),