[RAG] Add Ray implementation for distributed retrieval (#9197)
* wip * wip * wip * wip * wip * wip * wip * wip * uncomment * uncomment * wip * updates * add docstring * updates * fix arg * fixes * add unit tests * update readme * update readme * update finetune script * update test * add test * add ray to test dependencies * separate ray and ray tune * formatting * shutdown ray at end of test * fix tests * formatting * formatting * even more formatting * address comments * formatting * add files * Update examples/research_projects/rag/test_distributed_retriever.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * address comments * addressing comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -13,15 +13,27 @@ from datasets import Dataset
|
||||
import faiss
|
||||
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.integrations import is_ray_available
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me
|
||||
from transformers.testing_utils import require_ray, require_torch_non_multi_gpu_but_fix_me
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
|
||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
if is_torch_available():
|
||||
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
else:
|
||||
RagPyTorchDistributedRetriever = None
|
||||
|
||||
if is_ray_available():
|
||||
import ray # noqa: E402 # isort:skip
|
||||
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever # noqa: E402 # isort:skip
|
||||
else:
|
||||
ray = None
|
||||
RagRayDistributedRetriever = None
|
||||
RayRetriever = None
|
||||
|
||||
|
||||
def require_distributed_retrieval(test_case):
|
||||
@@ -32,8 +44,8 @@ def require_distributed_retrieval(test_case):
|
||||
These tests are skipped when respective libraries are not installed.
|
||||
|
||||
"""
|
||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
|
||||
if not (is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires Datasets, Faiss, psutil")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
@@ -144,7 +156,31 @@ class RagRetrieverTest(TestCase):
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||
def get_dummy_ray_distributed_retriever(self, init_retrieval: bool) -> RagRayDistributedRetriever:
|
||||
# Have to run in local mode because sys.path modifications at top of
|
||||
# file are not propogated to remote workers.
|
||||
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
|
||||
ray.init(local_mode=True)
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
)
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
workers = [remote_cls.remote() for _ in range(1)]
|
||||
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = self.get_dummy_dataset()
|
||||
retriever = RagRayDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
retrieval_workers=workers,
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval()
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_pytorch_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
@@ -175,13 +211,51 @@ class RagRetrieverTest(TestCase):
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool):
|
||||
# Have to run in local mode because sys.path modifications at top of
|
||||
# file are not propogated to remote workers.
|
||||
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
|
||||
ray.init(local_mode=True)
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="custom",
|
||||
)
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
workers = [remote_cls.remote() for _ in range(1)]
|
||||
if from_disk:
|
||||
config.passages_path = os.path.join(self.tmpdirname, "dataset")
|
||||
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
|
||||
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
|
||||
dataset.drop_index("embeddings")
|
||||
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
|
||||
del dataset
|
||||
retriever = RagRayDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
retrieval_workers=workers,
|
||||
index=CustomHFIndex.load_from_disk(
|
||||
vector_size=config.retrieval_vector_size,
|
||||
dataset_path=config.passages_path,
|
||||
index_path=config.index_path,
|
||||
),
|
||||
)
|
||||
else:
|
||||
retriever = RagRayDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
retrieval_workers=workers,
|
||||
index=CustomHFIndex(config.retrieval_vector_size, dataset),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval()
|
||||
return retriever
|
||||
|
||||
def distributed_retriever_check(self, retriever: RagRetriever, hidden_states: np.array, n_docs: int) -> None:
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
@@ -192,33 +266,76 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_hf_index_retriever_retrieve(self):
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs
|
||||
)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_hf_index_pytorch_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=False),
|
||||
hidden_states,
|
||||
n_docs,
|
||||
)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=True),
|
||||
hidden_states,
|
||||
n_docs,
|
||||
)
|
||||
|
||||
@require_ray
|
||||
def test_ray_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_ray_distributed_retriever(init_retrieval=True), hidden_states, n_docs
|
||||
)
|
||||
ray.shutdown()
|
||||
|
||||
@require_ray
|
||||
def test_custom_hf_index_ray_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=False),
|
||||
hidden_states,
|
||||
n_docs,
|
||||
)
|
||||
ray.shutdown()
|
||||
|
||||
@require_ray
|
||||
def test_custom_ray_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=True), hidden_states, n_docs
|
||||
)
|
||||
ray.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user