Allow Custom Dataset in RAG Retriever (#7763)
* add CustomHFIndex * typo in config * update tests * add custom dataset example * clean script * update test data * minor in test * docs * docs * style * fix imports * allow to pass the indexed dataset directly * update tests * use multiset DPR * address thom and patrick's comments * style * update dpr tokenizer * add output_dir flag in use_own_knowledge_dataset.py * allow custom datasets in examples/rag/finetune.py * add test for custom dataset in distributed rag retriever
This commit is contained in:
@@ -27,13 +27,18 @@ class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
It is used to decode the question and then use the generator_tokenizer.
|
||||
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
|
||||
If specified, use this index instead of the one built using the configuration
|
||||
"""
|
||||
|
||||
_init_retrieval = False
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
||||
super().__init__(
|
||||
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
)
|
||||
|
||||
self.process_group = None
|
||||
|
||||
Reference in New Issue
Block a user