Allow Custom Dataset in RAG Retriever (#7763)

* add CustomHFIndex

* typo in config

* update tests

* add custom dataset example

* clean script

* update test data

* minor in test

* docs

* docs

* style

* fix imports

* allow passing the indexed dataset directly

* update tests

* use multiset DPR

* address Thom and Patrick's comments

* style

* update dpr tokenizer

* add output_dir flag in use_own_knowledge_dataset.py

* allow custom datasets in examples/rag/finetune.py

* add test for custom dataset in distributed rag retriever
This commit is contained in:
Quentin Lhoest
2020-10-19 19:42:45 +02:00
committed by GitHub
parent a09fe140c1
commit 033f29c625
13 changed files with 663 additions and 98 deletions

View File

@@ -27,13 +27,18 @@ class RagPyTorchDistributedRetriever(RagRetriever):
It is used to decode the question and then use the generator_tokenizer.
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
The tokenizer used for the generator part of the RagModel.
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
If specified, this index is used instead of the one built from the configuration.
"""
_init_retrieval = False
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
super().__init__(
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
)
self.process_group = None