From 2dd652d757132d97e43173fb048849685ecccb68 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 25 Sep 2020 11:23:55 +0200 Subject: [PATCH] [RAG] Add missing doc and attention_mask to rag (#7382) * add docs * add missing docs and attention_mask in fine-tune --- examples/rag/finetune.py | 1 + src/transformers/modeling_rag.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/examples/rag/finetune.py b/examples/rag/finetune.py index 8648e678b3..4b39724875 100644 --- a/examples/rag/finetune.py +++ b/examples/rag/finetune.py @@ -265,6 +265,7 @@ class GenerativeQAModule(BaseTransformer): start_time = time.time() generated_ids = self.model.generate( batch["input_ids"], + attention_mask=batch["attention_mask"], do_deduplication=False, # rag specific parameter use_cache=True, min_length=1, diff --git a/src/transformers/modeling_rag.py b/src/transformers/modeling_rag.py index 09fe472dd8..c8c3e7eefe 100644 --- a/src/transformers/modeling_rag.py +++ b/src/transformers/modeling_rag.py @@ -831,6 +831,14 @@ class RagSequenceForGeneration(RagPreTrainedModel): input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then :obj:`context_input_ids` has to be provided. + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **maked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): Input IDs post-processed from the retrieved documents and the question encoder input_ids by the retriever. @@ -1207,6 +1215,14 @@ class RagTokenForGeneration(RagPreTrainedModel): input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then :obj:`context_input_ids` has to be provided. + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **maked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): Input IDs post-processed from the retrieved documents and the question encoder :obj:`input_ids` by the retriever.