Allow Custom Dataset in RAG Retriever (#7763)
* add CustomHFIndex * typo in config * update tests * add custom dataset example * clean script * update test data * minor in test * docs * docs * style * fix imports * allow to pass the indexed dataset directly * update tests * use multiset DPR * address thom and patrick's comments * style * update dpr tokenizer * add output_dir flag in use_own_knowledge_dataset.py * allow custom datasets in examples/rag/finetune.py * add test for custom dataset in distributed rag retriever
This commit is contained in:
@@ -13,7 +13,7 @@ import faiss
|
||||
from transformers.configuration_bart import BartConfig
|
||||
from transformers.configuration_dpr import DPRConfig
|
||||
from transformers.configuration_rag import RagConfig
|
||||
from transformers.retrieval_rag import RagRetriever
|
||||
from transformers.retrieval_rag import CustomHFIndex, RagRetriever
|
||||
from transformers.testing_utils import (
|
||||
require_datasets,
|
||||
require_faiss,
|
||||
@@ -103,7 +103,7 @@ class RagRetrieverTest(TestCase):
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_dummy_hf_index_retriever(self):
|
||||
def get_dummy_dataset(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
@@ -113,6 +113,10 @@ class RagRetrieverTest(TestCase):
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
return dataset
|
||||
|
||||
def get_dummy_canonical_hf_index_retriever(self):
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
@@ -127,6 +131,35 @@ class RagRetrieverTest(TestCase):
|
||||
)
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="custom",
|
||||
)
|
||||
if from_disk:
|
||||
config.passages_path = os.path.join(self.tmpdirname, "dataset")
|
||||
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
|
||||
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
|
||||
dataset.drop_index("embeddings")
|
||||
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
|
||||
del dataset
|
||||
retriever = RagRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
)
|
||||
else:
|
||||
retriever = RagRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
index=CustomHFIndex(config.retrieval_vector_size, dataset),
|
||||
)
|
||||
return retriever
|
||||
|
||||
def get_dummy_legacy_index_retriever(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
@@ -152,16 +185,15 @@ class RagRetrieverTest(TestCase):
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="legacy",
|
||||
index_path=self.tmpdirname,
|
||||
passages_path=self.tmpdirname,
|
||||
)
|
||||
retriever = RagRetriever(
|
||||
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
|
||||
)
|
||||
return retriever
|
||||
|
||||
def test_hf_index_retriever_retrieve(self):
|
||||
def test_canonical_hf_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_hf_index_retriever()
|
||||
retriever = self.get_dummy_canonical_hf_index_retriever()
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
@@ -174,10 +206,73 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_save_and_from_pretrained(self):
|
||||
retriever = self.get_dummy_hf_index_retriever()
|
||||
def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
|
||||
retriever = self.get_dummy_canonical_hf_index_retriever()
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = self.get_dummy_dataset()
|
||||
retriever.save_pretrained(tmp_dirname)
|
||||
retriever = RagRetriever.from_pretrained(tmp_dirname)
|
||||
self.assertIsInstance(retriever, RagRetriever)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
out = retriever.retrieve(hidden_states, n_docs=1)
|
||||
self.assertTrue(out is not None)
|
||||
|
||||
def test_custom_hf_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_custom_hf_index_retriever_save_and_from_pretrained(self):
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
retriever.save_pretrained(tmp_dirname)
|
||||
retriever = RagRetriever.from_pretrained(tmp_dirname)
|
||||
self.assertIsInstance(retriever, RagRetriever)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
out = retriever.retrieve(hidden_states, n_docs=1)
|
||||
self.assertTrue(out is not None)
|
||||
|
||||
def test_custom_hf_index_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_custom_hf_index_retriever_save_and_from_pretrained_from_disk(self):
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
retriever.save_pretrained(tmp_dirname)
|
||||
retriever = RagRetriever.from_pretrained(tmp_dirname)
|
||||
self.assertIsInstance(retriever, RagRetriever)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
out = retriever.retrieve(hidden_states, n_docs=1)
|
||||
self.assertTrue(out is not None)
|
||||
|
||||
def test_legacy_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
@@ -194,6 +289,18 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
|
||||
retriever = self.get_dummy_legacy_index_retriever()
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
retriever.save_pretrained(tmp_dirname)
|
||||
retriever = RagRetriever.from_pretrained(tmp_dirname)
|
||||
self.assertIsInstance(retriever, RagRetriever)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
out = retriever.retrieve(hidden_states, n_docs=1)
|
||||
self.assertTrue(out is not None)
|
||||
|
||||
@require_torch
|
||||
@require_tokenizers
|
||||
@require_sentencepiece
|
||||
@@ -201,7 +308,7 @@ class RagRetrieverTest(TestCase):
|
||||
import torch
|
||||
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_hf_index_retriever()
|
||||
retriever = self.get_dummy_canonical_hf_index_retriever()
|
||||
question_input_ids = [[5, 7], [10, 11]]
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
|
||||
Reference in New Issue
Block a user