use return dict for rag encoder (#9363)
This commit is contained in:
@@ -1437,7 +1437,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
batch_size = context_input_ids.shape[0] // n_docs
|
batch_size = context_input_ids.shape[0] // n_docs
|
||||||
|
|
||||||
encoder = self.rag.generator.get_encoder()
|
encoder = self.rag.generator.get_encoder()
|
||||||
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask)
|
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
|
||||||
|
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(batch_size * num_beams, 1),
|
(batch_size * num_beams, 1),
|
||||||
|
|||||||
Reference in New Issue
Block a user