[RAG] Add Ray implementation for distributed retrieval (#9197)
* wip * wip * wip * wip * wip * wip * wip * wip * uncomment * uncomment * wip * updates * add docstring * updates * fix arg * fixes * add unit tests * update readme * update readme * update finetune script * update test * add test * add ray to test dependencies * separate ray and ray tune * formatting * shutdown ray at end of test * fix tests * formatting * formatting * even more formatting * address comments * formatting * add files * Update examples/research_projects/rag/test_distributed_retriever.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * address comments * addressing comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@@ -50,6 +50,44 @@ python examples/rag/consolidate_rag_checkpoint.py \
```
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.

## Document Retrieval
When running distributed fine-tuning, each training worker needs to retrieve contextual documents
for its input by querying an index loaded into memory. RAG provides two implementations for document retrieval,
one using the [`torch.distributed`](https://pytorch.org/docs/stable/distributed.html) communication package and the other
using [`Ray`](https://docs.ray.io/en/master/).

This option can be configured with the `--distributed_retriever` flag, which can be set to either `pytorch` or `ray`.
By default, this flag is set to `pytorch`.

For the PyTorch implementation, only training worker 0 loads the index into CPU memory, and a gather/scatter pattern is used
to collect the inputs from the other training workers and send back the corresponding document embeddings.

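
The gather/scatter pattern can be illustrated without any distributed backend. The sketch below simulates it in a single process: plain lists stand in for the gather and scatter collectives, document ids stand in for document embeddings, and `TOY_INDEX`, `toy_index_lookup`, and `gather_scatter_retrieve` are hypothetical names chosen for illustration, not part of the RAG retriever's API.

```python
from typing import Dict, List

# Hypothetical toy index: maps a query string to a list of document ids.
# In the real setup this would be the retrieval index held in CPU memory.
TOY_INDEX: Dict[str, List[int]] = {
    "who wrote hamlet": [3, 7],
    "capital of france": [1, 4],
}


def toy_index_lookup(query: str) -> List[int]:
    """Stand-in for an index query; only rank 0 holds the index."""
    return TOY_INDEX.get(query, [])


def gather_scatter_retrieve(queries_per_rank: List[List[str]]) -> List[List[List[int]]]:
    """Simulate one retrieval step; entry i of the input belongs to worker i."""
    # Gather: rank 0 receives the query batch of every training worker.
    gathered = list(queries_per_rank)

    # Rank 0 answers all queries, because no other worker has the index.
    results = [[toy_index_lookup(q) for q in batch] for batch in gathered]

    # Scatter: entry i of the result list is sent back to worker i.
    return results


# Two simulated workers: the first sends one query, the second sends two.
results = gather_scatter_retrieve(
    [["who wrote hamlet"], ["capital of france", "unknown query"]]
)
print(results)
```

Every query passes through rank 0 in this pattern, so retrieval throughput is bounded by a single process regardless of how many training workers there are.
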
For the Ray implementation, the index is loaded in *separate* process(es). The training workers randomly select which
retrieval worker to query. To use Ray for distributed retrieval, set the `--distributed_retriever` flag to `ray`.
To configure the number of retrieval workers (the number of processes that load the index), set the `--num_retrieval_workers` flag.
Also make sure to start the Ray cluster before running fine-tuning.

```bash
# Start a single-node Ray cluster.
ray start --head

python examples/rag/finetune_rag.py \
    --data_dir $DATA_DIR \
    --output_dir $OUTPUT_DIR \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --model_type rag_sequence \
    --fp16 \
    --gpus 8 \
    --distributed_retriever ray \
    --num_retrieval_workers 4

# Stop the Ray cluster once fine-tuning has finished.
ray stop
```
Using Ray can lead to retrieval speedups in multi-GPU settings since multiple processes load the index rather than
just the rank 0 training worker. Using Ray also allows you to load the index on GPU since the index is loaded in a separate
process from the model, while with PyTorch distributed retrieval, both are loaded in the same process, potentially leading to GPU out-of-memory (OOM) errors.

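
The random selection of a retrieval worker can be sketched in plain Python. This is a single-process simulation and does not use Ray itself: `ToyRetrievalWorker` and `retrieve_from_random_worker` are hypothetical names, and in an actual Ray setup each worker would presumably live in its own process rather than in a local list.

```python
import random
from collections import Counter
from typing import Dict, List, Tuple


class ToyRetrievalWorker:
    """Hypothetical stand-in for a retrieval worker holding its own index copy."""

    def __init__(self, worker_id: int, index: Dict[str, List[int]]):
        self.worker_id = worker_id
        self.index = dict(index)  # each worker keeps a private copy of the index
        self.num_queries = 0

    def retrieve(self, query: str) -> List[int]:
        self.num_queries += 1
        return self.index.get(query, [])


def retrieve_from_random_worker(
    workers: List[ToyRetrievalWorker], query: str, rng: random.Random
) -> Tuple[int, List[int]]:
    """Pick one retrieval worker uniformly at random and query it."""
    worker = rng.choice(workers)
    return worker.worker_id, worker.retrieve(query)


index = {"capital of france": [1, 4]}
# Four workers, mirroring `--num_retrieval_workers 4` in the command above.
workers = [ToyRetrievalWorker(i, index) for i in range(4)]
rng = random.Random(0)  # seeded so the simulation is repeatable

load = Counter()
for _ in range(1000):
    worker_id, doc_ids = retrieve_from_random_worker(workers, "capital of france", rng)
    load[worker_id] += 1

print(dict(load))  # queries handled per worker
```

Because each query goes to a randomly chosen worker, the load spreads across all index-holding processes instead of concentrating on one.
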
# Evaluation
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e`, end-to-end evaluation, which returns EM (exact match) and F1 scores calculated for the downstream task, and `retrieval`, which returns precision@k of the documents retrieved for the provided inputs.

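
To make the three metrics concrete, here is a minimal sketch of how EM, F1, and precision@k are commonly computed. It assumes SQuAD-style answer normalization (lowercasing, stripping punctuation and English articles) and set-based precision@k; the evaluation script's exact normalization and input format may differ, and all function names here are illustrative.

```python
import re
import string
from collections import Counter
from typing import List, Set


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(reference))


def f1_score(prediction: str, reference: str) -> float:
    """Token-level F1 between the normalized prediction and reference."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    num_same = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def precision_at_k(retrieved: List[str], relevant: Set[str], k: int) -> float:
    """Fraction of the top-k retrieved documents that are relevant."""
    return len(set(retrieved[:k]) & relevant) / k


em = exact_match("The Eiffel Tower", "eiffel tower!")
f1 = f1_score("in Paris France", "Paris")
p_at_2 = precision_at_k(["d1", "d2", "d3", "d4"], {"d2", "d9"}, k=2)
print(em, f1, p_at_2)
```
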