[RAG] Add Ray implementation for distributed retrieval (#9197)

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* uncomment

* uncomment

* wip

* updates

* add docstring

* updates

* fix arg

* fixes

* add unit tests

* update readme

* update readme

* update finetune script

* update test

* add test

* add ray to test dependencies

* separate ray and ray tune

* formatting

* shutdown ray at end of test

* fix tests

* formatting

* formatting

* even more formatting

* address comments

* formatting

* add files

* Update examples/research_projects/rag/test_distributed_retriever.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address comments

* addressing comments

Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Amog Kamsetty
2020-12-21 01:39:30 -08:00
committed by GitHub
parent f38c4ad302
commit a4b21cdd20
14 changed files with 561 additions and 56 deletions

View File

@@ -50,6 +50,44 @@ python examples/rag/consolidate_rag_checkpoint.py \
```
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
## Document Retrieval
When running distributed fine-tuning, each training worker needs to retrieve contextual documents
for its input by querying a index loaded into memory. RAG provides two implementations for document retrieval,
one with [`torch.distributed`](https://pytorch.org/docs/stable/distributed.html) communication package and the other
with [`Ray`](https://docs.ray.io/en/master/).
This option can be configured with the `--distributed_retriever` flag which can either be set to `pytorch` or `ray`.
By default this flag is set to `pytorch`.
For the Pytorch implementation, only training worker 0 loads the index into CPU memory, and a gather/scatter pattern is used
to collect the inputs from the other training workers and send back the corresponding document embeddings.
For the Ray implementation, the index is loaded in *separate* process(es). The training workers randomly select which
retriever worker to query. To use Ray for distributed retrieval, you have to set the `--distributed_retriever` arg to `ray`.
To configure the number of retrieval workers (the number of processes that load the index), you can set the `num_retrieval_workers` flag.
Also make sure to start the Ray cluster before running fine-tuning.
```bash
# Start a single-node Ray cluster.
ray start --head
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
--model_type rag_sequence \
--fp16 \
--gpus 8
--distributed_retriever ray \
--num_retrieval_workers 4
# Stop the ray cluster once fine-tuning has finished.
ray stop
```
Using Ray can lead to retrieval speedups on multi-GPU settings since multiple processes load the index rather than
just the rank 0 training worker. Using Ray also allows you to load the index on GPU since the index is loaded on a separate
processes than the model, while with pytorch distributed retrieval, both are loaded in the same process potentially leading to GPU OOM.
# Evaluation
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.