[RAG] Add Ray implementation for distributed retrieval (#9197)

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* uncomment

* uncomment

* wip

* updates

* add docstring

* updates

* fix arg

* fixes

* add unit tests

* update readme

* update readme

* update finetune script

* update test

* add test

* add ray to test dependencies

* separate ray and ray tune

* formatting

* shutdown ray at end of test

* fix tests

* formatting

* formatting

* even more formatting

* address comments

* formatting

* add files

* Update examples/research_projects/rag/test_distributed_retriever.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address comments

* addressing comments

Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Amog Kamsetty
2020-12-21 01:39:30 -08:00
committed by GitHub
parent f38c4ad302
commit a4b21cdd20
14 changed files with 561 additions and 56 deletions

View File

@@ -13,15 +13,27 @@ from datasets import Dataset
import faiss
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
from transformers.integrations import is_ray_available
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.rag.retrieval_rag import CustomHFIndex
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me
from transformers.testing_utils import require_ray, require_torch_non_multi_gpu_but_fix_me
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
if is_torch_available():
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
else:
RagPyTorchDistributedRetriever = None
if is_ray_available():
import ray # noqa: E402 # isort:skip
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever # noqa: E402 # isort:skip
else:
ray = None
RagRayDistributedRetriever = None
RayRetriever = None
def require_distributed_retrieval(test_case):
@@ -32,8 +44,8 @@ def require_distributed_retrieval(test_case):
These tests are skipped when respective libraries are not installed.
"""
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
if not (is_datasets_available() and is_faiss_available() and is_psutil_available()):
test_case = unittest.skip("test requires Datasets, Faiss, psutil")(test_case)
return test_case
@@ -144,7 +156,31 @@ class RagRetrieverTest(TestCase):
retriever.init_retrieval(port)
return retriever
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
def get_dummy_ray_distributed_retriever(self, init_retrieval: bool) -> RagRayDistributedRetriever:
# Have to run in local mode because sys.path modifications at top of
# file are not propogated to remote workers.
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
ray.init(local_mode=True)
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
)
remote_cls = ray.remote(RayRetriever)
workers = [remote_cls.remote() for _ in range(1)]
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = self.get_dummy_dataset()
retriever = RagRayDistributedRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
retrieval_workers=workers,
)
if init_retrieval:
retriever.init_retrieval()
return retriever
def get_dummy_custom_hf_index_pytorch_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
dataset = self.get_dummy_dataset()
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
@@ -175,13 +211,51 @@ class RagRetrieverTest(TestCase):
retriever.init_retrieval(port)
return retriever
@require_torch_non_multi_gpu_but_fix_me
def test_pytorch_distributed_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool):
# Have to run in local mode because sys.path modifications at top of
# file are not propogated to remote workers.
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
ray.init(local_mode=True)
dataset = self.get_dummy_dataset()
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="custom",
)
remote_cls = ray.remote(RayRetriever)
workers = [remote_cls.remote() for _ in range(1)]
if from_disk:
config.passages_path = os.path.join(self.tmpdirname, "dataset")
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
dataset.drop_index("embeddings")
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
del dataset
retriever = RagRayDistributedRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
retrieval_workers=workers,
index=CustomHFIndex.load_from_disk(
vector_size=config.retrieval_vector_size,
dataset_path=config.passages_path,
index_path=config.index_path,
),
)
else:
retriever = RagRayDistributedRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
retrieval_workers=workers,
index=CustomHFIndex(config.retrieval_vector_size, dataset),
)
if init_retrieval:
retriever.init_retrieval()
return retriever
def distributed_retriever_check(self, retriever: RagRetriever, hidden_states: np.array, n_docs: int) -> None:
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
@@ -192,33 +266,76 @@ class RagRetrieverTest(TestCase):
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
@require_torch_non_multi_gpu_but_fix_me
def test_custom_hf_index_retriever_retrieve(self):
def test_pytorch_distributed_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
self.distributed_retriever_check(
self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs
)
@require_torch_non_multi_gpu_but_fix_me
def test_custom_hf_index_pytorch_retriever_retrieve(self):
n_docs = 1
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
self.distributed_retriever_check(
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=False),
hidden_states,
n_docs,
)
@require_torch_non_multi_gpu_but_fix_me
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
n_docs = 1
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
self.distributed_retriever_check(
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=True),
hidden_states,
n_docs,
)
@require_ray
def test_ray_distributed_retriever_retrieve(self):
n_docs = 1
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
self.distributed_retriever_check(
self.get_dummy_ray_distributed_retriever(init_retrieval=True), hidden_states, n_docs
)
ray.shutdown()
@require_ray
def test_custom_hf_index_ray_retriever_retrieve(self):
n_docs = 1
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
with self.assertRaises(ValueError):
self.distributed_retriever_check(
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=False),
hidden_states,
n_docs,
)
ray.shutdown()
@require_ray
def test_custom_ray_distributed_retriever_retrieve_from_disk(self):
n_docs = 1
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
self.distributed_retriever_check(
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=True), hidden_states, n_docs
)
ray.shutdown()