[RAG] Add Ray implementation for distributed retrieval (#9197)
* wip * wip * wip * wip * wip * wip * wip * wip * uncomment * uncomment * wip * updates * add docstring * updates * fix arg * fixes * add unit tests * update readme * update readme * update finetune script * update test * add test * add ray to test dependencies * separate ray and ray tune * formatting * shutdown ray at end of test * fix tests * formatting * formatting * even more formatting * address comments * formatting * add files * Update examples/research_projects/rag/test_distributed_retriever.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * address comments * addressing comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -19,3 +19,4 @@ pytest
|
|||||||
conllu
|
conllu
|
||||||
sentencepiece != 0.1.92
|
sentencepiece != 0.1.92
|
||||||
protobuf
|
protobuf
|
||||||
|
ray
|
||||||
|
|||||||
@@ -50,6 +50,44 @@ python examples/rag/consolidate_rag_checkpoint.py \
|
|||||||
```
|
```
|
||||||
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
|
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
|
||||||
|
|
||||||
|
## Document Retrieval
|
||||||
|
When running distributed fine-tuning, each training worker needs to retrieve contextual documents
|
||||||
|
for its input by querying an index loaded into memory. RAG provides two implementations for document retrieval,
|
||||||
|
one with [`torch.distributed`](https://pytorch.org/docs/stable/distributed.html) communication package and the other
|
||||||
|
with [`Ray`](https://docs.ray.io/en/master/).
|
||||||
|
|
||||||
|
This option can be configured with the `--distributed_retriever` flag which can either be set to `pytorch` or `ray`.
|
||||||
|
By default this flag is set to `pytorch`.
|
||||||
|
|
||||||
|
For the PyTorch implementation, only training worker 0 loads the index into CPU memory, and a gather/scatter pattern is used
|
||||||
|
to collect the inputs from the other training workers and send back the corresponding document embeddings.
|
||||||
|
|
||||||
|
For the Ray implementation, the index is loaded in *separate* process(es). The training workers randomly select which
|
||||||
|
retriever worker to query. To use Ray for distributed retrieval, you have to set the `--distributed_retriever` arg to `ray`.
|
||||||
|
To configure the number of retrieval workers (the number of processes that load the index), you can set the `--num_retrieval_workers` flag.
|
||||||
|
Also make sure to start the Ray cluster before running fine-tuning.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start a single-node Ray cluster.
|
||||||
|
ray start --head
|
||||||
|
|
||||||
|
python examples/rag/finetune_rag.py \
|
||||||
|
--data_dir $DATA_DIR \
|
||||||
|
--output_dir $OUTPUT_DIR \
|
||||||
|
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||||
|
--model_type rag_sequence \
|
||||||
|
--fp16 \
|
||||||
|
--gpus 8 \
|
||||||
|
--distributed_retriever ray \
|
||||||
|
--num_retrieval_workers 4
|
||||||
|
|
||||||
|
# Stop the ray cluster once fine-tuning has finished.
|
||||||
|
ray stop
|
||||||
|
```
|
||||||
|
|
||||||
|
Using Ray can lead to retrieval speedups on multi-GPU settings since multiple processes load the index rather than
|
||||||
|
just the rank 0 training worker. Using Ray also allows you to load the index on GPU since the index is loaded in a separate
|
||||||
|
process from the model, while with PyTorch distributed retrieval, both are loaded in the same process, potentially leading to GPU OOM.
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.
|
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from transformers.file_utils import is_apex_available
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
execute_subprocess_async,
|
execute_subprocess_async,
|
||||||
|
require_ray,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
)
|
)
|
||||||
@@ -29,7 +30,7 @@ class RagFinetuneExampleTests(TestCasePlus):
|
|||||||
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
|
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
|
|
||||||
def _run_finetune(self, gpus: int):
|
def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
@@ -66,6 +67,7 @@ class RagFinetuneExampleTests(TestCasePlus):
|
|||||||
--gradient_accumulation_steps 1 \
|
--gradient_accumulation_steps 1 \
|
||||||
--distributed-port 8787 \
|
--distributed-port 8787 \
|
||||||
--use_dummy_dataset 1 \
|
--use_dummy_dataset 1 \
|
||||||
|
--distributed_retriever {distributed_retriever} \
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if gpus > 0:
|
if gpus > 0:
|
||||||
@@ -94,3 +96,15 @@ class RagFinetuneExampleTests(TestCasePlus):
|
|||||||
def test_finetune_multigpu(self):
|
def test_finetune_multigpu(self):
|
||||||
result = self._run_finetune(gpus=2)
|
result = self._run_finetune(gpus=2)
|
||||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||||
|
|
||||||
|
@require_torch_gpu
@require_ray
def test_finetune_gpu_ray_retrieval(self):
    """Fine-tune on one GPU with ``distributed_retriever="ray"`` and check the exact-match score."""
    finetune_output = self._run_finetune(gpus=1, distributed_retriever="ray")
    test_avg_em = finetune_output["test"][0]["test_avg_em"]
    self.assertGreaterEqual(test_avg_em, 0.2)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
@require_ray
def test_finetune_multigpu_ray_retrieval(self):
    """Fine-tune with ``distributed_retriever="ray"`` on a multi-GPU machine and check the exact-match score.

    NOTE(review): the run is launched with ``gpus=1`` although the test is
    gated on multiple GPUs -- confirm whether ``gpus=2`` was intended.
    """
    finetune_output = self._run_finetune(gpus=1, distributed_retriever="ray")
    test_avg_em = finetune_output["test"][0]["test_avg_em"]
    self.assertGreaterEqual(test_avg_em, 0.2)
|
||||||
|
|||||||
@@ -31,14 +31,13 @@ class RagPyTorchDistributedRetriever(RagRetriever):
|
|||||||
If specified, use this index instead of the one built using the configuration
|
If specified, use this index instead of the one built using the configuration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_init_retrieval = False
|
|
||||||
|
|
||||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
config,
|
config,
|
||||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||||
generator_tokenizer=generator_tokenizer,
|
generator_tokenizer=generator_tokenizer,
|
||||||
index=index,
|
index=index,
|
||||||
|
init_retrieval=False,
|
||||||
)
|
)
|
||||||
self.process_group = None
|
self.process_group = None
|
||||||
|
|
||||||
154
examples/research_projects/rag/distributed_ray_retriever.py
Normal file
154
examples/research_projects/rag/distributed_ray_retriever.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
|
import ray
|
||||||
|
from transformers import RagConfig, RagRetriever, RagTokenizer
|
||||||
|
from transformers.file_utils import requires_datasets, requires_faiss
|
||||||
|
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RayRetriever:
    """
    Thin holder around a :class:`RagRetriever` whose construction is deferred
    until :meth:`create_rag_retriever` is called.
    """

    def __init__(self):
        # No wrapped retriever exists yet; see ``create_rag_retriever``.
        self.initialized = False

    def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
        """Build the wrapped ``RagRetriever`` on the first call; later calls do nothing."""
        if self.initialized:
            return
        self.retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
            init_retrieval=False,
        )
        self.initialized = True

    def init_retrieval(self):
        """Initialize the index of the wrapped retriever."""
        wrapped_index = self.retriever.index
        wrapped_index.init_index()

    def retrieve(self, question_hidden_states, n_docs):
        """Run the wrapped retriever's ``_main_retrieve`` and return ``(doc_ids, retrieved_doc_embeds)``."""
        ids, embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
        return ids, embeds
|
||||||
|
|
||||||
|
|
||||||
|
class RagRayDistributedRetriever(RagRetriever):
    """
    A distributed retriever built on top of the ``Ray`` API, a library for building distributed applications
    (https://docs.ray.io/en/master/). During training, every training worker creates its own instance of a
    `RagRayDistributedRetriever`, and all of these instances share a common set of retrieval Ray actors
    (https://docs.ray.io/en/master/walkthrough.html#remote-classes-actors) that load the index on separate
    processes. Ray handles the communication between the `RagRayDistributedRetriever` instances and the remote Ray
    actors. When no retrieval actors are given (non-distributed setup), the index is simply loaded in the same
    process as the training worker and Ray is not used.

    Args:
        config (:class:`~transformers.RagConfig`):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which
            ``Index`` to build.
        question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
            The tokenizer that was used to tokenize the question. It is used to decode the question and then use
            the generator_tokenizer.
        generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
        retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`):
            A list of already initialized `RayRetriever` actors. These actor classes run on remote processes and
            are responsible for performing the index lookup.
        index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
            If specified, use this index instead of the one built using the configuration
    """

    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
        # An already-initialized index is rejected when remote workers are in use (see the message below).
        index_already_loaded = index is not None and index.is_initialized()
        if index_already_loaded and len(retrieval_workers) > 0:
            raise ValueError(
                "When using Ray for distributed fine-tuning, "
                "you'll need to provide the paths instead, "
                "as the dataset and the index are loaded "
                "separately. More info in examples/rag/use_own_knowledge_dataset.py "
            )
        super().__init__(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
            init_retrieval=False,
        )
        self.retrieval_workers = retrieval_workers
        if len(self.retrieval_workers) > 0:
            # Ask every actor to build its retriever and block until all of them are done.
            pending = [
                actor.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
                for actor in self.retrieval_workers
            ]
            ray.get(pending)

    def init_retrieval(self):
        """
        Retriever initialization function, to be called from the training process. With retrieval actors, it
        triggers retrieval initialization on every actor and waits for completion; without them, it loads the
        index into the current process.
        """
        logger.info("initializing retrieval")

        if len(self.retrieval_workers) == 0:
            # Non-distributed training. Load index into this same process.
            self.index.init_index()
            return

        ray.get([actor.init_retrieval.remote() for actor in self.retrieval_workers])

    def retrieve(self, question_hidden_states, n_docs):
        """
        Retrieves documents for specified ``question_hidden_states``. If retrieval actors are available, one of
        them is picked at random to perform the index lookup and return the result.

        Args:
            question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.

        Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
                The ids of the documents in the index
            doc_dicts (:obj:`List[dict]`):
                The retrieved_doc_embeds examples per query.
        """
        workers = self.retrieval_workers
        if len(workers) > 0:
            # Select a random retrieval actor.
            chosen = random.choice(workers)
            doc_ids, retrieved_doc_embeds = ray.get(chosen.retrieve.remote(question_hidden_states, n_docs))
        else:
            doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
        doc_dicts = self.index.get_doc_dicts(doc_ids)
        return retrieved_doc_embeds, doc_ids, doc_dicts

    @classmethod
    def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        """Delegate to the parent class implementation of ``get_tokenizers``."""
        return super().get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)

    @classmethod
    def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
        """
        Build a retriever from ``retriever_name_or_path``, wiring in the given Ray ``actor_handles``.

        A ``config`` passed through ``kwargs`` is used when truthy; otherwise the configuration is loaded from
        ``retriever_name_or_path``. When ``indexed_dataset`` is given, a ``CustomHFIndex`` over it is used and
        ``config.index_name`` is set to ``"custom"``; otherwise the index is built from the configuration.
        """
        requires_datasets(cls)
        requires_faiss(cls)

        config = kwargs.pop("config", None)
        if not config:
            config = RagConfig.from_pretrained(retriever_name_or_path, **kwargs)

        rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
        q_tokenizer = rag_tokenizer.question_encoder
        gen_tokenizer = rag_tokenizer.generator

        if indexed_dataset is None:
            index = cls._build_index(config)
        else:
            config.index_name = "custom"
            index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)

        return cls(
            config,
            question_encoder_tokenizer=q_tokenizer,
            generator_tokenizer=gen_tokenizer,
            retrieval_workers=actor_handles,
            index=index,
        )
|
||||||
@@ -29,6 +29,12 @@ from transformers import (
|
|||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
)
|
)
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
from transformers.integrations import is_ray_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_ray_available():
|
||||||
|
import ray
|
||||||
|
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
|
||||||
|
|
||||||
|
|
||||||
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||||
@@ -36,7 +42,8 @@ from callbacks_rag import ( # noqa: E402 # isort:skipq
|
|||||||
get_early_stopping_callback,
|
get_early_stopping_callback,
|
||||||
Seq2SeqLoggingCallback,
|
Seq2SeqLoggingCallback,
|
||||||
)
|
)
|
||||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
|
||||||
|
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||||
from utils_rag import ( # noqa: E402 # isort:skip
|
from utils_rag import ( # noqa: E402 # isort:skip
|
||||||
calculate_exact_match,
|
calculate_exact_match,
|
||||||
flatten_list,
|
flatten_list,
|
||||||
@@ -88,7 +95,12 @@ class CustomAccel(DDPAccelerator):
|
|||||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
||||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
||||||
if module.is_rag_model:
|
if module.is_rag_model:
|
||||||
|
if module.distributed_retriever == "pytorch":
|
||||||
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||||
|
elif module.distributed_retriever == "ray" and global_rank == 0:
|
||||||
|
# For the Ray retriever, only initialize it once when global
|
||||||
|
# rank is 0.
|
||||||
|
module.model.rag.retriever.init_retrieval()
|
||||||
|
|
||||||
|
|
||||||
class GenerativeQAModule(BaseTransformer):
|
class GenerativeQAModule(BaseTransformer):
|
||||||
@@ -127,7 +139,13 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
config.generator.prefix = hparams.prefix
|
config.generator.prefix = hparams.prefix
|
||||||
config.label_smoothing = hparams.label_smoothing
|
config.label_smoothing = hparams.label_smoothing
|
||||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||||
|
if hparams.distributed_retriever == "pytorch":
|
||||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
||||||
|
elif hparams.distributed_retriever == "ray":
|
||||||
|
# The Ray retriever needs the handles to the retriever actors.
|
||||||
|
retriever = RagRayDistributedRetriever.from_pretrained(
|
||||||
|
hparams.model_name_or_path, hparams.actor_handles, config=config
|
||||||
|
)
|
||||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||||
prefix = config.question_encoder.prefix
|
prefix = config.question_encoder.prefix
|
||||||
else:
|
else:
|
||||||
@@ -180,8 +198,13 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
# For single GPU training, init_ddp_connection is not called.
|
# For single GPU training, init_ddp_connection is not called.
|
||||||
# So we need to initialize the retrievers here.
|
# So we need to initialize the retrievers here.
|
||||||
if hparams.gpus <= 1:
|
if hparams.gpus <= 1:
|
||||||
|
if hparams.distributed_retriever == "ray":
|
||||||
|
self.model.retriever.init_retrieval()
|
||||||
|
elif hparams.distributed_retriever == "pytorch":
|
||||||
self.model.retriever.init_retrieval(self.distributed_port)
|
self.model.retriever.init_retrieval(self.distributed_port)
|
||||||
|
|
||||||
|
self.distributed_retriever = hparams.distributed_retriever
|
||||||
|
|
||||||
def forward(self, input_ids, **kwargs):
|
def forward(self, input_ids, **kwargs):
|
||||||
return self.model(input_ids, **kwargs)
|
return self.model(input_ids, **kwargs)
|
||||||
|
|
||||||
@@ -420,6 +443,7 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
type=str,
|
type=str,
|
||||||
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -442,12 +466,58 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
default=None,
|
default=None,
|
||||||
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--distributed_retriever",
|
||||||
|
choices=["ray", "pytorch"],
|
||||||
|
type=str,
|
||||||
|
default="pytorch",
|
||||||
|
help="What implementation to use for distributed retriever? If "
|
||||||
|
"pytorch is selected, the index is loaded on training "
|
||||||
|
"worker 0, and torch.distributed is used to handle "
|
||||||
|
"communication between training worker 0, and the other "
|
||||||
|
"training workers. If ray is selected, the Ray library is "
|
||||||
|
"used to create load the index on separate processes, "
|
||||||
|
"and Ray handles the communication between the training "
|
||||||
|
"workers and the retrieval actors.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_dummy_dataset",
|
"--use_dummy_dataset",
|
||||||
type=bool,
|
type=bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_retrieval_workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The number of retrieval actors to use when Ray is selected"
|
||||||
|
"for the distributed retriever. Has no effect when "
|
||||||
|
"distributed_retriever is set to pytorch.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
def add_ray_specific_args(parser):
    """
    Register the Ray-related command line options on ``parser``.

    Adds ``--num_retrieval_workers`` (unless the parser already defines it)
    and ``--ray-address``.

    Args:
        parser (:obj:`argparse.ArgumentParser`): The parser to extend in place.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    # ``--num_retrieval_workers`` is also registered together with the
    # retriever-specific arguments. argparse rejects a second registration of
    # the same option with ArgumentError (before modifying the parser), so keep
    # the existing definition in that case instead of crashing at start-up.
    try:
        parser.add_argument(
            "--num_retrieval_workers",
            type=int,
            default=1,
            help="The number of retrieval actors to use when Ray is selected "
            "for the distributed retriever. Has no effect when "
            "distributed_retriever is set to pytorch.",
        )
    except argparse.ArgumentError:
        pass

    # Ray cluster address.
    parser.add_argument(
        "--ray-address",
        default="auto",
        type=str,
        help="The address of the Ray cluster to connect to. If not "
        "specified, Ray will attempt to automatically detect the "
        "cluster. Has no effect if pytorch is used as the distributed "
        "retriever.",
    )

    return parser
|
||||||
|
|
||||||
|
|
||||||
@@ -461,6 +531,46 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
|||||||
args = args or parser.parse_args()
|
args = args or parser.parse_args()
|
||||||
|
|
||||||
Path(args.output_dir).mkdir(exist_ok=True)
|
Path(args.output_dir).mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Handles to the Ray retrieval actors; stays empty unless the Ray retriever
# is used with more than one GPU.
named_actors = []
if args.distributed_retriever == "ray" and args.gpus > 1:
    if not is_ray_available():
        raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
    # Connect to an existing Ray cluster.
    try:
        ray.init(address=args.ray_address)
    except (ConnectionError, ValueError):
        logger.warning(
            "Connection to Ray cluster failed. Make sure a Ray "
            "cluster is running by either using Ray's cluster "
            "launcher (`ray up`) or by manually starting Ray on "
            "each node via `ray start --head` for the head node "
            "and `ray start --address='<ip address>:6379'` for "
            "additional nodes. See "
            "https://docs.ray.io/en/master/cluster/index.html "
            "for more info."
        )
        raise

    # Create Ray actors only for rank 0.
    # Environment variables are always strings, so the ranks are compared
    # against "0" (comparing with the integer 0 is never true). An unset
    # variable is treated as rank 0.
    local_rank = os.environ.get("LOCAL_RANK", "0")
    node_rank = os.environ.get("NODE_RANK", "0")
    if local_rank == "0" and node_rank == "0":
        remote_cls = ray.remote(RayRetriever)
        named_actors = [
            remote_cls.options(name="retrieval_worker_{}".format(i)).remote()
            for i in range(args.num_retrieval_workers)
        ]
    else:
        # Every other process looks up the named actors instead of creating them.
        logger.info("Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(node_rank, local_rank))
        named_actors = [ray.get_actor("retrieval_worker_{}".format(i)) for i in range(args.num_retrieval_workers)]
# Expose the handles on the parsed arguments (read later as ``hparams.actor_handles``).
args.actor_handles = named_actors
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
model: GenerativeQAModule = GenerativeQAModule(args)
|
model: GenerativeQAModule = GenerativeQAModule(args)
|
||||||
|
|
||||||
@@ -471,17 +581,17 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
|||||||
or str(args.output_dir).startswith("/tmp")
|
or str(args.output_dir).startswith("/tmp")
|
||||||
or str(args.output_dir).startswith("/var")
|
or str(args.output_dir).startswith("/var")
|
||||||
):
|
):
|
||||||
logger = True # don't pollute wandb logs unnecessarily
|
training_logger = True # don't pollute wandb logs unnecessarily
|
||||||
elif args.logger_name == "wandb":
|
elif args.logger_name == "wandb":
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||||
logger = WandbLogger(name=model.output_dir.name, project=project)
|
training_logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||||
|
|
||||||
elif args.logger_name == "wandb_shared":
|
elif args.logger_name == "wandb_shared":
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||||
|
|
||||||
es_callback = (
|
es_callback = (
|
||||||
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||||
@@ -495,8 +605,9 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
|||||||
logging_callback=Seq2SeqLoggingCallback(),
|
logging_callback=Seq2SeqLoggingCallback(),
|
||||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||||
early_stopping_callback=es_callback,
|
early_stopping_callback=es_callback,
|
||||||
logger=logger,
|
logger=training_logger,
|
||||||
accelerator=CustomAccel() if args.gpus > 1 else None,
|
accelerator=CustomAccel() if args.gpus > 1 else None,
|
||||||
|
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
|
||||||
)
|
)
|
||||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||||
|
|
||||||
@@ -509,4 +620,19 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||||
|
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||||
|
parser = GenerativeQAModule.add_ray_specific_args(parser)
|
||||||
|
|
||||||
|
# Pytorch Lightning Profiler
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile",
|
||||||
|
action="store_true",
|
||||||
|
help="If True, use pytorch_lightning.profiler.AdvancedProfiler to profile the Trainer.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||||
|
|
||||||
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
||||||
# run ./examples/rag/finetune.sh --help to see all the possible options
|
# run ./examples/rag/finetune_rag.sh --help to see all the possible options
|
||||||
|
|
||||||
python examples/rag/finetune_rag.py \
|
python examples/rag/finetune_rag.py \
|
||||||
--data_dir $DATA_DIR \
|
--data_dir $DATA_DIR \
|
||||||
@@ -11,10 +11,10 @@ python examples/rag/finetune_rag.py \
|
|||||||
--model_type rag_sequence \
|
--model_type rag_sequence \
|
||||||
--fp16 \
|
--fp16 \
|
||||||
--gpus 8 \
|
--gpus 8 \
|
||||||
|
--profile \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_predict \
|
--do_predict \
|
||||||
--n_val -1 \
|
--n_val -1 \
|
||||||
--val_check_interval 0.25 \
|
|
||||||
--train_batch_size 8 \
|
--train_batch_size 8 \
|
||||||
--eval_batch_size 1 \
|
--eval_batch_size 1 \
|
||||||
--max_source_length 128 \
|
--max_source_length 128 \
|
||||||
@@ -31,4 +31,4 @@ python examples/rag/finetune_rag.py \
|
|||||||
--learning_rate 3e-05 \
|
--learning_rate 3e-05 \
|
||||||
--num_train_epochs 100 \
|
--num_train_epochs 100 \
|
||||||
--warmup_steps 500 \
|
--warmup_steps 500 \
|
||||||
--gradient_accumulation_steps 1
|
--gradient_accumulation_steps 1 \
|
||||||
|
|||||||
44
examples/research_projects/rag/finetune_rag_ray.sh
Executable file
44
examples/research_projects/rag/finetune_rag_ray.sh
Executable file
@@ -0,0 +1,44 @@
|
|||||||
|
# Sample script to finetune RAG using Ray for distributed retrieval.
# DATA_DIR, OUTPUT_DIR and MODEL_NAME_OR_PATH are read from the environment.

# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"

# Start a single-node Ray cluster.
ray start --head

# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options
# Extra arguments given to this script are forwarded to finetune_rag.py ("$@");
# paths are quoted so that values containing spaces are passed as one argument.

python examples/rag/finetune_rag.py \
    --data_dir "$DATA_DIR" \
    --output_dir "$OUTPUT_DIR" \
    --model_name_or_path "$MODEL_NAME_OR_PATH" \
    --model_type rag_sequence \
    --fp16 \
    --gpus 8 \
    --profile \
    --do_train \
    --do_predict \
    --n_val -1 \
    --train_batch_size 8 \
    --eval_batch_size 1 \
    --max_source_length 128 \
    --max_target_length 25 \
    --val_max_target_length 25 \
    --test_max_target_length 25 \
    --label_smoothing 0.1 \
    --dropout 0.1 \
    --attention_dropout 0.1 \
    --weight_decay 0.001 \
    --adam_epsilon 1e-08 \
    --max_grad_norm 0.1 \
    --lr_scheduler polynomial \
    --learning_rate 3e-05 \
    --num_train_epochs 100 \
    --warmup_steps 500 \
    --gradient_accumulation_steps 1 \
    --distributed_retriever ray \
    --num_retrieval_workers 4 \
    "$@"

# Stop the Ray cluster.
ray stop
|
||||||
@@ -13,15 +13,27 @@ from datasets import Dataset
|
|||||||
import faiss
|
import faiss
|
||||||
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
|
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
|
||||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||||
|
from transformers.integrations import is_ray_available
|
||||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
||||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||||
from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me
|
from transformers.testing_utils import require_ray, require_torch_non_multi_gpu_but_fix_me
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||||
|
|
||||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
if is_torch_available():
|
||||||
|
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||||
|
else:
|
||||||
|
RagPyTorchDistributedRetriever = None
|
||||||
|
|
||||||
|
if is_ray_available():
|
||||||
|
import ray # noqa: E402 # isort:skip
|
||||||
|
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever # noqa: E402 # isort:skip
|
||||||
|
else:
|
||||||
|
ray = None
|
||||||
|
RagRayDistributedRetriever = None
|
||||||
|
RayRetriever = None
|
||||||
|
|
||||||
|
|
||||||
def require_distributed_retrieval(test_case):
|
def require_distributed_retrieval(test_case):
|
||||||
@@ -32,8 +44,8 @@ def require_distributed_retrieval(test_case):
|
|||||||
These tests are skipped when respective libraries are not installed.
|
These tests are skipped when respective libraries are not installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
if not (is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||||
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
|
test_case = unittest.skip("test requires Datasets, Faiss, psutil")(test_case)
|
||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
@@ -144,7 +156,31 @@ class RagRetrieverTest(TestCase):
|
|||||||
retriever.init_retrieval(port)
|
retriever.init_retrieval(port)
|
||||||
return retriever
|
return retriever
|
||||||
|
|
||||||
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
def get_dummy_ray_distributed_retriever(self, init_retrieval: bool) -> RagRayDistributedRetriever:
|
||||||
|
# Have to run in local mode because sys.path modifications at top of
|
||||||
|
# file are not propagated to remote workers.
|
||||||
|
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
|
||||||
|
ray.init(local_mode=True)
|
||||||
|
config = RagConfig(
|
||||||
|
retrieval_vector_size=self.retrieval_vector_size,
|
||||||
|
question_encoder=DPRConfig().to_dict(),
|
||||||
|
generator=BartConfig().to_dict(),
|
||||||
|
)
|
||||||
|
remote_cls = ray.remote(RayRetriever)
|
||||||
|
workers = [remote_cls.remote() for _ in range(1)]
|
||||||
|
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||||
|
mock_load_dataset.return_value = self.get_dummy_dataset()
|
||||||
|
retriever = RagRayDistributedRetriever(
|
||||||
|
config,
|
||||||
|
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||||
|
generator_tokenizer=self.get_bart_tokenizer(),
|
||||||
|
retrieval_workers=workers,
|
||||||
|
)
|
||||||
|
if init_retrieval:
|
||||||
|
retriever.init_retrieval()
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
def get_dummy_custom_hf_index_pytorch_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||||
dataset = self.get_dummy_dataset()
|
dataset = self.get_dummy_dataset()
|
||||||
config = RagConfig(
|
config = RagConfig(
|
||||||
retrieval_vector_size=self.retrieval_vector_size,
|
retrieval_vector_size=self.retrieval_vector_size,
|
||||||
@@ -175,13 +211,51 @@ class RagRetrieverTest(TestCase):
|
|||||||
retriever.init_retrieval(port)
|
retriever.init_retrieval(port)
|
||||||
return retriever
|
return retriever
|
||||||
|
|
||||||
@require_torch_non_multi_gpu_but_fix_me
|
def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool):
|
||||||
def test_pytorch_distributed_retriever_retrieve(self):
|
# Have to run in local mode because sys.path modifications at top of
|
||||||
n_docs = 1
|
# file are not propagated to remote workers.
|
||||||
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
|
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
|
||||||
hidden_states = np.array(
|
ray.init(local_mode=True)
|
||||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
dataset = self.get_dummy_dataset()
|
||||||
|
config = RagConfig(
|
||||||
|
retrieval_vector_size=self.retrieval_vector_size,
|
||||||
|
question_encoder=DPRConfig().to_dict(),
|
||||||
|
generator=BartConfig().to_dict(),
|
||||||
|
index_name="custom",
|
||||||
)
|
)
|
||||||
|
remote_cls = ray.remote(RayRetriever)
|
||||||
|
workers = [remote_cls.remote() for _ in range(1)]
|
||||||
|
if from_disk:
|
||||||
|
config.passages_path = os.path.join(self.tmpdirname, "dataset")
|
||||||
|
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
|
||||||
|
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
|
||||||
|
dataset.drop_index("embeddings")
|
||||||
|
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
|
||||||
|
del dataset
|
||||||
|
retriever = RagRayDistributedRetriever(
|
||||||
|
config,
|
||||||
|
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||||
|
generator_tokenizer=self.get_bart_tokenizer(),
|
||||||
|
retrieval_workers=workers,
|
||||||
|
index=CustomHFIndex.load_from_disk(
|
||||||
|
vector_size=config.retrieval_vector_size,
|
||||||
|
dataset_path=config.passages_path,
|
||||||
|
index_path=config.index_path,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
retriever = RagRayDistributedRetriever(
|
||||||
|
config,
|
||||||
|
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||||
|
generator_tokenizer=self.get_bart_tokenizer(),
|
||||||
|
retrieval_workers=workers,
|
||||||
|
index=CustomHFIndex(config.retrieval_vector_size, dataset),
|
||||||
|
)
|
||||||
|
if init_retrieval:
|
||||||
|
retriever.init_retrieval()
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
def distributed_retriever_check(self, retriever: RagRetriever, hidden_states: np.array, n_docs: int) -> None:
|
||||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||||
self.assertEqual(len(doc_dicts), 2)
|
self.assertEqual(len(doc_dicts), 2)
|
||||||
@@ -192,33 +266,76 @@ class RagRetrieverTest(TestCase):
|
|||||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||||
|
|
||||||
@require_torch_non_multi_gpu_but_fix_me
|
@require_torch_non_multi_gpu_but_fix_me
|
||||||
def test_custom_hf_index_retriever_retrieve(self):
|
def test_pytorch_distributed_retriever_retrieve(self):
|
||||||
n_docs = 1
|
n_docs = 1
|
||||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
|
|
||||||
hidden_states = np.array(
|
hidden_states = np.array(
|
||||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||||
)
|
)
|
||||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
|
||||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
self.distributed_retriever_check(
|
||||||
self.assertEqual(len(doc_dicts), 2)
|
self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs
|
||||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
)
|
||||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
|
||||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
@require_torch_non_multi_gpu_but_fix_me
|
||||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
def test_custom_hf_index_pytorch_retriever_retrieve(self):
|
||||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
n_docs = 1
|
||||||
|
hidden_states = np.array(
|
||||||
|
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
self.distributed_retriever_check(
|
||||||
|
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=False),
|
||||||
|
hidden_states,
|
||||||
|
n_docs,
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch_non_multi_gpu_but_fix_me
|
@require_torch_non_multi_gpu_but_fix_me
|
||||||
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
|
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
|
||||||
n_docs = 1
|
n_docs = 1
|
||||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
|
|
||||||
hidden_states = np.array(
|
hidden_states = np.array(
|
||||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||||
)
|
)
|
||||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
|
||||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
self.distributed_retriever_check(
|
||||||
self.assertEqual(len(doc_dicts), 2)
|
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=True),
|
||||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
hidden_states,
|
||||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
n_docs,
|
||||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
)
|
||||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
|
||||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
@require_ray
|
||||||
|
def test_ray_distributed_retriever_retrieve(self):
|
||||||
|
n_docs = 1
|
||||||
|
hidden_states = np.array(
|
||||||
|
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
self.distributed_retriever_check(
|
||||||
|
self.get_dummy_ray_distributed_retriever(init_retrieval=True), hidden_states, n_docs
|
||||||
|
)
|
||||||
|
ray.shutdown()
|
||||||
|
|
||||||
|
@require_ray
|
||||||
|
def test_custom_hf_index_ray_retriever_retrieve(self):
|
||||||
|
n_docs = 1
|
||||||
|
hidden_states = np.array(
|
||||||
|
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.distributed_retriever_check(
|
||||||
|
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=False),
|
||||||
|
hidden_states,
|
||||||
|
n_docs,
|
||||||
|
)
|
||||||
|
ray.shutdown()
|
||||||
|
|
||||||
|
@require_ray
|
||||||
|
def test_custom_ray_distributed_retriever_retrieve_from_disk(self):
|
||||||
|
n_docs = 1
|
||||||
|
hidden_states = np.array(
|
||||||
|
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
self.distributed_retriever_check(
|
||||||
|
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=True), hidden_states, n_docs
|
||||||
|
)
|
||||||
|
ray.shutdown()
|
||||||
|
|||||||
@@ -219,6 +219,7 @@ from .integrations import ( # isort:skip
|
|||||||
is_comet_available,
|
is_comet_available,
|
||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
is_ray_available,
|
is_ray_available,
|
||||||
|
is_ray_tune_available,
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
is_wandb_available,
|
is_wandb_available,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -63,8 +63,16 @@ try:
|
|||||||
import ray # noqa: F401
|
import ray # noqa: F401
|
||||||
|
|
||||||
_has_ray = True
|
_has_ray = True
|
||||||
|
try:
|
||||||
|
# Ray Tune has additional dependencies.
|
||||||
|
from ray import tune # noqa: F401
|
||||||
|
|
||||||
|
_has_ray_tune = True
|
||||||
|
except (ImportError):
|
||||||
|
_has_ray_tune = False
|
||||||
except (ImportError):
|
except (ImportError):
|
||||||
_has_ray = False
|
_has_ray = False
|
||||||
|
_has_ray_tune = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter # noqa: F401
|
from torch.utils.tensorboard import SummaryWriter # noqa: F401
|
||||||
@@ -127,6 +135,10 @@ def is_ray_available():
|
|||||||
return _has_ray
|
return _has_ray
|
||||||
|
|
||||||
|
|
||||||
|
def is_ray_tune_available():
|
||||||
|
return _has_ray_tune
|
||||||
|
|
||||||
|
|
||||||
def is_azureml_available():
|
def is_azureml_available():
|
||||||
return _has_azureml
|
return _has_azureml
|
||||||
|
|
||||||
@@ -143,7 +155,7 @@ def hp_params(trial):
|
|||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
if isinstance(trial, optuna.Trial):
|
if isinstance(trial, optuna.Trial):
|
||||||
return trial.params
|
return trial.params
|
||||||
if is_ray_available():
|
if is_ray_tune_available():
|
||||||
if isinstance(trial, dict):
|
if isinstance(trial, dict):
|
||||||
return trial
|
return trial
|
||||||
|
|
||||||
@@ -153,7 +165,7 @@ def hp_params(trial):
|
|||||||
def default_hp_search_backend():
|
def default_hp_search_backend():
|
||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
return "optuna"
|
return "optuna"
|
||||||
elif is_ray_available():
|
elif is_ray_tune_available():
|
||||||
return "ray"
|
return "ray"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -370,9 +370,8 @@ class RagRetriever:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_init_retrieval = True
|
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
|
||||||
|
self._init_retrieval = init_retrieval
|
||||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
|
||||||
requires_datasets(self)
|
requires_datasets(self)
|
||||||
requires_faiss(self)
|
requires_faiss(self)
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from .integrations import ( # isort: split
|
|||||||
is_fairscale_available,
|
is_fairscale_available,
|
||||||
is_mlflow_available,
|
is_mlflow_available,
|
||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
is_ray_available,
|
is_ray_tune_available,
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
is_wandb_available,
|
is_wandb_available,
|
||||||
run_hp_search_optuna,
|
run_hp_search_optuna,
|
||||||
@@ -145,7 +145,7 @@ if is_mlflow_available():
|
|||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
import optuna
|
import optuna
|
||||||
|
|
||||||
if is_ray_available():
|
if is_ray_tune_available():
|
||||||
from ray import tune
|
from ray import tune
|
||||||
|
|
||||||
if is_azureml_available():
|
if is_azureml_available():
|
||||||
@@ -1062,7 +1062,7 @@ class Trainer:
|
|||||||
backend = HPSearchBackend(backend)
|
backend = HPSearchBackend(backend)
|
||||||
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
|
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
|
||||||
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
|
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
|
||||||
if backend == HPSearchBackend.RAY and not is_ray_available():
|
if backend == HPSearchBackend.RAY and not is_ray_tune_available():
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -132,9 +132,9 @@ def default_hp_space_optuna(trial) -> Dict[str, float]:
|
|||||||
|
|
||||||
|
|
||||||
def default_hp_space_ray(trial) -> Dict[str, float]:
|
def default_hp_space_ray(trial) -> Dict[str, float]:
|
||||||
from .integrations import is_ray_available
|
from .integrations import is_ray_tune_available
|
||||||
|
|
||||||
assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`"
|
assert is_ray_tune_available(), "This function needs ray installed: `pip " "install ray[tune]`"
|
||||||
from ray import tune
|
from ray import tune
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
Reference in New Issue
Block a user