[RAG] Add missing doc and attention_mask to rag (#7382)
* add docs
* add missing docs and attention_mask in fine-tune
committed by GitHub
parent 7cdd9da5bf
commit 2dd652d757
@@ -265,6 +265,7 @@ class GenerativeQAModule(BaseTransformer):
    start_time = time.time()
    generated_ids = self.model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        do_deduplication=False,  # rag specific parameter
        use_cache=True,
        min_length=1,
@@ -831,6 +831,14 @@ class RagSequenceForGeneration(RagPreTrainedModel):
    input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then
        :obj:`context_input_ids` has to be provided.
    attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        Mask to avoid performing attention on padding token indices.
        Mask values selected in ``[0, 1]``:

        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.

        `What are attention masks? <../glossary.html#attention-mask>`__
    context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
        Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
        retriever.
@@ -1207,6 +1215,14 @@ class RagTokenForGeneration(RagPreTrainedModel):
    input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then
        :obj:`context_input_ids` has to be provided.
    attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        Mask to avoid performing attention on padding token indices.
        Mask values selected in ``[0, 1]``:

        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.

        `What are attention masks? <../glossary.html#attention-mask>`__
    context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
        Input IDs post-processed from the retrieved documents and the question encoder :obj:`input_ids` by the
        retriever.
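
Why the mask matters for `generate` can be sketched with a masked softmax over one row of attention scores. This is a simplified illustration, not the RAG implementation: the function name `masked_softmax` is hypothetical, and real models usually apply the mask by adding a large negative value to the masked scores before the softmax rather than zeroing terms directly.

```python
import math
from typing import List


def masked_softmax(scores: List[float], attention_mask: List[int]) -> List[float]:
    """Softmax over scores, giving zero weight wherever attention_mask is 0.

    Hypothetical helper for illustration; assumes at least one unmasked position.
    """
    kept = [s for s, m in zip(scores, attention_mask) if m == 1]
    if not kept:
        raise ValueError("attention_mask must keep at least one position")
    peak = max(kept)  # subtract the max kept score for numerical stability
    exps = [
        math.exp(s - peak) if m == 1 else 0.0
        for s, m in zip(scores, attention_mask)
    ]
    total = sum(exps)
    return [e / total for e in exps]


# The third position is padding: its large raw score must not attract attention.
weights = masked_softmax([2.0, 1.0, 9.0], [1, 1, 0])
```

Without the mask, the padding position would dominate the distribution because of its large raw score. With the mask, all of the weight is shared among the real tokens.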