[RAG] Add Ray implementation for distributed retrieval (#9197)
* wip * wip * wip * wip * wip * wip * wip * wip * uncomment * uncomment * wip * updates * add docstring * updates * fix arg * fixes * add unit tests * update readme * update readme * update finetune script * update test * add test * add ray to test dependencies * separate ray and ray tune * formatting * shutdown ray at end of test * fix tests * formatting * formatting * even more formatting * address comments * formatting * add files * Update examples/research_projects/rag/test_distributed_retriever.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * address comments * addressing comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -19,3 +19,4 @@ pytest
|
||||
conllu
|
||||
sentencepiece != 0.1.92
|
||||
protobuf
|
||||
ray
|
||||
|
||||
@@ -50,6 +50,44 @@ python examples/rag/consolidate_rag_checkpoint.py \
|
||||
```
|
||||
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
|
||||
|
||||
## Document Retrieval
|
||||
When running distributed fine-tuning, each training worker needs to retrieve contextual documents
|
||||
for its input by querying a index loaded into memory. RAG provides two implementations for document retrieval,
|
||||
one with [`torch.distributed`](https://pytorch.org/docs/stable/distributed.html) communication package and the other
|
||||
with [`Ray`](https://docs.ray.io/en/master/).
|
||||
|
||||
This option can be configured with the `--distributed_retriever` flag which can either be set to `pytorch` or `ray`.
|
||||
By default this flag is set to `pytorch`.
|
||||
|
||||
For the Pytorch implementation, only training worker 0 loads the index into CPU memory, and a gather/scatter pattern is used
|
||||
to collect the inputs from the other training workers and send back the corresponding document embeddings.
|
||||
|
||||
For the Ray implementation, the index is loaded in *separate* process(es). The training workers randomly select which
|
||||
retriever worker to query. To use Ray for distributed retrieval, you have to set the `--distributed_retriever` arg to `ray`.
|
||||
To configure the number of retrieval workers (the number of processes that load the index), you can set the `num_retrieval_workers` flag.
|
||||
Also make sure to start the Ray cluster before running fine-tuning.
|
||||
|
||||
```bash
|
||||
# Start a single-node Ray cluster.
|
||||
ray start --head
|
||||
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8
|
||||
--distributed_retriever ray \
|
||||
--num_retrieval_workers 4
|
||||
|
||||
# Stop the ray cluster once fine-tuning has finished.
|
||||
ray stop
|
||||
```
|
||||
|
||||
Using Ray can lead to retrieval speedups on multi-GPU settings since multiple processes load the index rather than
|
||||
just the rank 0 training worker. Using Ray also allows you to load the index on GPU since the index is loaded on a separate
|
||||
processes than the model, while with pytorch distributed retrieval, both are loaded in the same process potentially leading to GPU OOM.
|
||||
|
||||
# Evaluation
|
||||
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.
|
||||
|
||||
@@ -9,6 +9,7 @@ from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
require_ray,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
@@ -29,7 +30,7 @@ class RagFinetuneExampleTests(TestCasePlus):
|
||||
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
|
||||
f.write(content)
|
||||
|
||||
def _run_finetune(self, gpus: int):
|
||||
def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
@@ -66,6 +67,7 @@ class RagFinetuneExampleTests(TestCasePlus):
|
||||
--gradient_accumulation_steps 1 \
|
||||
--distributed-port 8787 \
|
||||
--use_dummy_dataset 1 \
|
||||
--distributed_retriever {distributed_retriever} \
|
||||
""".split()
|
||||
|
||||
if gpus > 0:
|
||||
@@ -94,3 +96,15 @@ class RagFinetuneExampleTests(TestCasePlus):
|
||||
def test_finetune_multigpu(self):
|
||||
result = self._run_finetune(gpus=2)
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_ray
|
||||
def test_finetune_gpu_ray_retrieval(self):
|
||||
result = self._run_finetune(gpus=1, distributed_retriever="ray")
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@require_ray
|
||||
def test_finetune_multigpu_ray_retrieval(self):
|
||||
result = self._run_finetune(gpus=1, distributed_retriever="ray")
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
|
||||
@@ -31,14 +31,13 @@ class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
If specified, use this index instead of the one built using the configuration
|
||||
"""
|
||||
|
||||
_init_retrieval = False
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
||||
super().__init__(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
init_retrieval=False,
|
||||
)
|
||||
self.process_group = None
|
||||
|
||||
154
examples/research_projects/rag/distributed_ray_retriever.py
Normal file
154
examples/research_projects/rag/distributed_ray_retriever.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import ray
|
||||
from transformers import RagConfig, RagRetriever, RagTokenizer
|
||||
from transformers.file_utils import requires_datasets, requires_faiss
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayRetriever:
|
||||
def __init__(self):
|
||||
self.initialized = False
|
||||
|
||||
def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
|
||||
if not self.initialized:
|
||||
self.retriever = RagRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
init_retrieval=False,
|
||||
)
|
||||
self.initialized = True
|
||||
|
||||
def init_retrieval(self):
|
||||
self.retriever.index.init_index()
|
||||
|
||||
def retrieve(self, question_hidden_states, n_docs):
|
||||
doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
|
||||
return doc_ids, retrieved_doc_embeds
|
||||
|
||||
|
||||
class RagRayDistributedRetriever(RagRetriever):
|
||||
"""
|
||||
A distributed retriever built on top of the ``Ray`` API, a library
|
||||
for building distributed applications (https://docs.ray.io/en/master/).
|
||||
package. During training, all training workers initialize their own
|
||||
instance of a `RagRayDistributedRetriever`, and each instance of
|
||||
this distributed retriever shares a common set of Retrieval Ray
|
||||
Actors (https://docs.ray.io/en/master/walkthrough.html#remote
|
||||
-classes-actors) that load the index on separate processes. Ray
|
||||
handles the communication between the `RagRayDistributedRetriever`
|
||||
instances and the remote Ray actors. If training is done in a
|
||||
non-distributed setup, the index will simply be loaded in the same
|
||||
process as the training worker and Ray will not be used.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.RagConfig`):
|
||||
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
|
||||
question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer that was used to tokenize the question.
|
||||
It is used to decode the question and then use the generator_tokenizer.
|
||||
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors.
|
||||
These actor classes run on remote processes and are responsible for performing the index lookup.
|
||||
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
|
||||
If specified, use this index instead of the one built using the configuration
|
||||
"""
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
|
||||
if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
|
||||
raise ValueError(
|
||||
"When using Ray for distributed fine-tuning, "
|
||||
"you'll need to provide the paths instead, "
|
||||
"as the dataset and the index are loaded "
|
||||
"separately. More info in examples/rag/use_own_knowledge_dataset.py "
|
||||
)
|
||||
super().__init__(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
init_retrieval=False,
|
||||
)
|
||||
self.retrieval_workers = retrieval_workers
|
||||
if len(self.retrieval_workers) > 0:
|
||||
ray.get(
|
||||
[
|
||||
worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
|
||||
for worker in self.retrieval_workers
|
||||
]
|
||||
)
|
||||
|
||||
def init_retrieval(self):
|
||||
"""
|
||||
Retriever initialization function, needs to be called from the
|
||||
training process. This function triggers retrieval initialization
|
||||
for all retrieval actors if using distributed setting, or loads
|
||||
index into current process if training is not distributed.
|
||||
"""
|
||||
logger.info("initializing retrieval")
|
||||
|
||||
if len(self.retrieval_workers) > 0:
|
||||
ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
|
||||
else:
|
||||
# Non-distributed training. Load index into this same process.
|
||||
self.index.init_index()
|
||||
|
||||
def retrieve(self, question_hidden_states, n_docs):
|
||||
"""
|
||||
Retrieves documents for specified ``question_hidden_states``. If
|
||||
running training with multiple workers, a random retrieval actor is
|
||||
selected to perform the index lookup and return the result.
|
||||
|
||||
Args:
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
|
||||
A batch of query vectors to retrieve with.
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Output:
|
||||
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
|
||||
The retrieval embeddings of the retrieved docs per query.
|
||||
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
|
||||
The ids of the documents in the index
|
||||
doc_dicts (:obj:`List[dict]`):
|
||||
The retrieved_doc_embeds examples per query.
|
||||
"""
|
||||
if len(self.retrieval_workers) > 0:
|
||||
# Select a random retrieval actor.
|
||||
random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)]
|
||||
doc_ids, retrieved_doc_embeds = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs))
|
||||
else:
|
||||
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
|
||||
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
|
||||
|
||||
@classmethod
|
||||
def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
|
||||
return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
|
||||
requires_datasets(cls)
|
||||
requires_faiss(cls)
|
||||
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
|
||||
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
|
||||
question_encoder_tokenizer = rag_tokenizer.question_encoder
|
||||
generator_tokenizer = rag_tokenizer.generator
|
||||
if indexed_dataset is not None:
|
||||
config.index_name = "custom"
|
||||
index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
|
||||
else:
|
||||
index = cls._build_index(config)
|
||||
return cls(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
retrieval_workers=actor_handles,
|
||||
index=index,
|
||||
)
|
||||
@@ -29,6 +29,12 @@ from transformers import (
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers import logging as transformers_logging
|
||||
from transformers.integrations import is_ray_available
|
||||
|
||||
|
||||
if is_ray_available():
|
||||
import ray
|
||||
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
|
||||
|
||||
|
||||
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||
@@ -36,7 +42,8 @@ from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||
get_early_stopping_callback,
|
||||
Seq2SeqLoggingCallback,
|
||||
)
|
||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
|
||||
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
from utils_rag import ( # noqa: E402 # isort:skip
|
||||
calculate_exact_match,
|
||||
flatten_list,
|
||||
@@ -88,7 +95,12 @@ class CustomAccel(DDPAccelerator):
|
||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
||||
if module.is_rag_model:
|
||||
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||
if module.distributed_retriever == "pytorch":
|
||||
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||
elif module.distributed_retriever == "ray" and global_rank == 0:
|
||||
# For the Ray retriever, only initialize it once when global
|
||||
# rank is 0.
|
||||
module.model.rag.retriever.init_retrieval()
|
||||
|
||||
|
||||
class GenerativeQAModule(BaseTransformer):
|
||||
@@ -127,7 +139,13 @@ class GenerativeQAModule(BaseTransformer):
|
||||
config.generator.prefix = hparams.prefix
|
||||
config.label_smoothing = hparams.label_smoothing
|
||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
if hparams.distributed_retriever == "pytorch":
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
elif hparams.distributed_retriever == "ray":
|
||||
# The Ray retriever needs the handles to the retriever actors.
|
||||
retriever = RagRayDistributedRetriever.from_pretrained(
|
||||
hparams.model_name_or_path, hparams.actor_handles, config=config
|
||||
)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||
prefix = config.question_encoder.prefix
|
||||
else:
|
||||
@@ -180,7 +198,12 @@ class GenerativeQAModule(BaseTransformer):
|
||||
# For single GPU training, init_ddp_connection is not called.
|
||||
# So we need to initialize the retrievers here.
|
||||
if hparams.gpus <= 1:
|
||||
self.model.retriever.init_retrieval(self.distributed_port)
|
||||
if hparams.distributed_retriever == "ray":
|
||||
self.model.retriever.init_retrieval()
|
||||
elif hparams.distributed_retriever == "pytorch":
|
||||
self.model.retriever.init_retrieval(self.distributed_port)
|
||||
|
||||
self.distributed_retriever = hparams.distributed_retriever
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
@@ -420,6 +443,7 @@ class GenerativeQAModule(BaseTransformer):
|
||||
type=str,
|
||||
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
@@ -442,12 +466,58 @@ class GenerativeQAModule(BaseTransformer):
|
||||
default=None,
|
||||
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed_retriever",
|
||||
choices=["ray", "pytorch"],
|
||||
type=str,
|
||||
default="pytorch",
|
||||
help="What implementation to use for distributed retriever? If "
|
||||
"pytorch is selected, the index is loaded on training "
|
||||
"worker 0, and torch.distributed is used to handle "
|
||||
"communication between training worker 0, and the other "
|
||||
"training workers. If ray is selected, the Ray library is "
|
||||
"used to create load the index on separate processes, "
|
||||
"and Ray handles the communication between the training "
|
||||
"workers and the retrieval actors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dummy_dataset",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_retrieval_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of retrieval actors to use when Ray is selected"
|
||||
"for the distributed retriever. Has no effect when "
|
||||
"distributed_retriever is set to pytorch.",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_ray_specific_args(parser):
|
||||
parser.add_argument(
|
||||
"--num_retrieval_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of retrieval actors to use when Ray is selected"
|
||||
"for the distributed retriever. Has no effect when "
|
||||
"distributed_retriever is set to pytorch.",
|
||||
)
|
||||
|
||||
# Ray cluster address.
|
||||
parser.add_argument(
|
||||
"--ray-address",
|
||||
default="auto",
|
||||
type=str,
|
||||
help="The address of the Ray cluster to connect to. If not "
|
||||
"specified, Ray will attempt to automatically detect the "
|
||||
"cluster. Has no effect if pytorch is used as the distributed "
|
||||
"retriever.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -461,6 +531,46 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
args = args or parser.parse_args()
|
||||
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
|
||||
named_actors = []
|
||||
if args.distributed_retriever == "ray" and args.gpus > 1:
|
||||
if not is_ray_available():
|
||||
raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
|
||||
# Connect to an existing Ray cluster.
|
||||
try:
|
||||
ray.init(address=args.ray_address)
|
||||
except (ConnectionError, ValueError):
|
||||
logger.warning(
|
||||
"Connection to Ray cluster failed. Make sure a Ray"
|
||||
"cluster is running by either using Ray's cluster "
|
||||
"launcher (`ray up`) or by manually starting Ray on "
|
||||
"each node via `ray start --head` for the head node "
|
||||
"and `ray start --address='<ip address>:6379'` for "
|
||||
"additional nodes. See "
|
||||
"https://docs.ray.io/en/master/cluster/index.html "
|
||||
"for more info."
|
||||
)
|
||||
raise
|
||||
|
||||
# Create Ray actors only for rank 0.
|
||||
if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and (
|
||||
"NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0
|
||||
):
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
named_actors = [
|
||||
remote_cls.options(name="retrieval_worker_{}".format(i)).remote()
|
||||
for i in range(args.num_retrieval_workers)
|
||||
]
|
||||
else:
|
||||
logger.info(
|
||||
"Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(
|
||||
os.environ["NODE_RANK"], os.environ["LOCAL_RANK"]
|
||||
)
|
||||
)
|
||||
named_actors = [ray.get_actor("retrieval_worker_{}".format(i)) for i in range(args.num_retrieval_workers)]
|
||||
args.actor_handles = named_actors
|
||||
assert args.actor_handles == named_actors
|
||||
|
||||
if model is None:
|
||||
model: GenerativeQAModule = GenerativeQAModule(args)
|
||||
|
||||
@@ -471,17 +581,17 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
training_logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
training_logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
|
||||
es_callback = (
|
||||
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||
@@ -495,8 +605,9 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
logger=training_logger,
|
||||
accelerator=CustomAccel() if args.gpus > 1 else None,
|
||||
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
|
||||
@@ -509,4 +620,19 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||
parser = GenerativeQAModule.add_ray_specific_args(parser)
|
||||
|
||||
# Pytorch Lightning Profiler
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="If True, use pytorch_lightning.profiler.AdvancedProfiler to profile the Trainer.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./examples/rag/finetune.sh --help to see all the possible options
|
||||
# run ./examples/rag/finetune_rag.sh --help to see all the possible options
|
||||
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
@@ -11,10 +11,10 @@ python examples/rag/finetune_rag.py \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8 \
|
||||
--profile \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--val_check_interval 0.25 \
|
||||
--train_batch_size 8 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 128 \
|
||||
@@ -31,4 +31,4 @@ python examples/rag/finetune_rag.py \
|
||||
--learning_rate 3e-05 \
|
||||
--num_train_epochs 100 \
|
||||
--warmup_steps 500 \
|
||||
--gradient_accumulation_steps 1
|
||||
--gradient_accumulation_steps 1 \
|
||||
|
||||
44
examples/research_projects/rag/finetune_rag_ray.sh
Executable file
44
examples/research_projects/rag/finetune_rag_ray.sh
Executable file
@@ -0,0 +1,44 @@
|
||||
# Sample script to finetune RAG using Ray for distributed retrieval.
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# Start a single-node Ray cluster.
|
||||
ray start --head
|
||||
|
||||
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options
|
||||
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8 \
|
||||
--profile \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--train_batch_size 8 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 128 \
|
||||
--max_target_length 25 \
|
||||
--val_max_target_length 25 \
|
||||
--test_max_target_length 25 \
|
||||
--label_smoothing 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
--weight_decay 0.001 \
|
||||
--adam_epsilon 1e-08 \
|
||||
--max_grad_norm 0.1 \
|
||||
--lr_scheduler polynomial \
|
||||
--learning_rate 3e-05 \
|
||||
--num_train_epochs 100 \
|
||||
--warmup_steps 500 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--distributed_retriever ray \
|
||||
--num_retrieval_workers 4
|
||||
|
||||
# Stop the Ray cluster.
|
||||
ray stop
|
||||
@@ -13,15 +13,27 @@ from datasets import Dataset
|
||||
import faiss
|
||||
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.integrations import is_ray_available
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me
|
||||
from transformers.testing_utils import require_ray, require_torch_non_multi_gpu_but_fix_me
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
|
||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
if is_torch_available():
|
||||
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
else:
|
||||
RagPyTorchDistributedRetriever = None
|
||||
|
||||
if is_ray_available():
|
||||
import ray # noqa: E402 # isort:skip
|
||||
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever # noqa: E402 # isort:skip
|
||||
else:
|
||||
ray = None
|
||||
RagRayDistributedRetriever = None
|
||||
RayRetriever = None
|
||||
|
||||
|
||||
def require_distributed_retrieval(test_case):
|
||||
@@ -32,8 +44,8 @@ def require_distributed_retrieval(test_case):
|
||||
These tests are skipped when respective libraries are not installed.
|
||||
|
||||
"""
|
||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
|
||||
if not (is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires Datasets, Faiss, psutil")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
@@ -144,7 +156,31 @@ class RagRetrieverTest(TestCase):
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||
def get_dummy_ray_distributed_retriever(self, init_retrieval: bool) -> RagRayDistributedRetriever:
|
||||
# Have to run in local mode because sys.path modifications at top of
|
||||
# file are not propogated to remote workers.
|
||||
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
|
||||
ray.init(local_mode=True)
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
)
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
workers = [remote_cls.remote() for _ in range(1)]
|
||||
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = self.get_dummy_dataset()
|
||||
retriever = RagRayDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
retrieval_workers=workers,
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval()
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_pytorch_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
@@ -175,13 +211,51 @@ class RagRetrieverTest(TestCase):
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool):
|
||||
# Have to run in local mode because sys.path modifications at top of
|
||||
# file are not propogated to remote workers.
|
||||
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
|
||||
ray.init(local_mode=True)
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="custom",
|
||||
)
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
workers = [remote_cls.remote() for _ in range(1)]
|
||||
if from_disk:
|
||||
config.passages_path = os.path.join(self.tmpdirname, "dataset")
|
||||
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
|
||||
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
|
||||
dataset.drop_index("embeddings")
|
||||
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
|
||||
del dataset
|
||||
retriever = RagRayDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
retrieval_workers=workers,
|
||||
index=CustomHFIndex.load_from_disk(
|
||||
vector_size=config.retrieval_vector_size,
|
||||
dataset_path=config.passages_path,
|
||||
index_path=config.index_path,
|
||||
),
|
||||
)
|
||||
else:
|
||||
retriever = RagRayDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
retrieval_workers=workers,
|
||||
index=CustomHFIndex(config.retrieval_vector_size, dataset),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval()
|
||||
return retriever
|
||||
|
||||
def distributed_retriever_check(self, retriever: RagRetriever, hidden_states: np.array, n_docs: int) -> None:
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
@@ -192,33 +266,76 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_hf_index_retriever_retrieve(self):
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs
|
||||
)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_hf_index_pytorch_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=False),
|
||||
hidden_states,
|
||||
n_docs,
|
||||
)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=True),
|
||||
hidden_states,
|
||||
n_docs,
|
||||
)
|
||||
|
||||
@require_ray
|
||||
def test_ray_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_ray_distributed_retriever(init_retrieval=True), hidden_states, n_docs
|
||||
)
|
||||
ray.shutdown()
|
||||
|
||||
@require_ray
|
||||
def test_custom_hf_index_ray_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=False),
|
||||
hidden_states,
|
||||
n_docs,
|
||||
)
|
||||
ray.shutdown()
|
||||
|
||||
@require_ray
|
||||
def test_custom_ray_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=True), hidden_states, n_docs
|
||||
)
|
||||
ray.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user