Rag end2end new (#17650)
* check * update the RAG-end2end with new PL and RAY * removed unwanted comments
This commit is contained in:
@@ -15,6 +15,10 @@ This code can be modified to experiment with other research on retrival augmente
|
||||
|
||||
To start training, use the bash script (finetune_rag_ray_end2end.sh) in this folder. This script also includes descriptions on each command-line argument used.
|
||||
|
||||
# Latest Update
|
||||
|
||||
⚠️ Updated the rag-end2end-retriever to be compatible with PL==1.6.4 and RAY==1.13.0 (latest versions to the date 2022-June-11)
|
||||
|
||||
# Note
|
||||
|
||||
⚠️ This project should be run with pytorch-lightning==1.3.1 which has a potential security vulnerability
|
||||
@@ -22,12 +26,14 @@ To start training, use the bash script (finetune_rag_ray_end2end.sh) in this fol
|
||||
# Testing
|
||||
|
||||
The following two bash scripts can be used to quickly test the implementation.
|
||||
1. sh ./test_run/test_rag_new_features.sh
|
||||
- Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag.
|
||||
- This is sufficient to check the model's ability to use the set functions correctly.
|
||||
2. sh ./test_run/test_finetune.sh script
|
||||
1. sh ./test_run/test_finetune.sh script
|
||||
- Tests the full end-to-end fine-tuning ability with a dummy knowlendge-base and dummy training dataset (check test_dir directory).
|
||||
- Users can replace the dummy dataset and knowledge-base with their own to do their own finetuning.
|
||||
- Please read the comments in the test_finetune.sh file.
|
||||
2. sh ./test_run/test_rag_new_features.sh
|
||||
- Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag.
|
||||
- This is sufficient to check the model's ability to use the set functions correctly.
|
||||
|
||||
|
||||
|
||||
# Comparison of end2end RAG (including DPR finetuning) VS original-RAG
|
||||
|
||||
@@ -41,7 +41,7 @@ def get_checkpoint_callback(output_dir, metric):
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
save_top_k=1,
|
||||
every_n_val_epochs=1, # works only with PL > 1.3
|
||||
every_n_epochs=1, # works only with PL > 1.3
|
||||
)
|
||||
|
||||
return checkpoint_callback
|
||||
|
||||
@@ -350,6 +350,7 @@ class GenerativeQAModule(BaseTransformer):
|
||||
concat.save_to_disk(self.config.passages_path) # here we update the main passage file on the disk
|
||||
logger.info("done updating the dataset")
|
||||
|
||||
# To Do (@Aaron) : Useful in the future dynamic memory implementation.
|
||||
# if you load the index from the disk make sure to update the index file here, otherwise it is ok to update the index file from the worker.
|
||||
# logger.info("then updating the index")
|
||||
# shutil.copy(self.custom_config.temp_index, self.config.idex_path)
|
||||
@@ -360,10 +361,7 @@ class GenerativeQAModule(BaseTransformer):
|
||||
|
||||
isEmUpdateBusy = False
|
||||
isAddIndexBusy = False
|
||||
|
||||
self.trainer.accelerator_connector.accelerator.barrier(
|
||||
"barrier"
|
||||
) # waint untill the index and kb get re-initialized.
|
||||
self.trainer.strategy.barrier("barrier")
|
||||
|
||||
loss_tensors = self._step(batch)
|
||||
|
||||
@@ -724,7 +722,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
|
||||
# Connect to an existing Ray cluster.
|
||||
try:
|
||||
ray.init(address=args.ray_address)
|
||||
ray.init(address=args.ray_address, namespace="rag")
|
||||
except (ConnectionError, ValueError):
|
||||
logger.warning(
|
||||
"Connection to Ray cluster failed. Make sure a Ray"
|
||||
|
||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.plugins.training_type import DDPPlugin
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
from transformers import (
|
||||
@@ -386,24 +385,22 @@ def generic_train(
|
||||
|
||||
train_params = {}
|
||||
|
||||
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
||||
if args.fp16:
|
||||
train_params["precision"] = 16
|
||||
train_params["amp_level"] = args.fp16_opt_level
|
||||
|
||||
if args.gpus > 1:
|
||||
train_params["accelerator"] = "ddp"
|
||||
train_params["accelerator"] = "auto"
|
||||
train_params["strategy"] = "ddp"
|
||||
|
||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||
# train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
|
||||
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None)
|
||||
train_params["profiler"] = None
|
||||
train_params["devices"] = "auto"
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
weights_summary=None,
|
||||
callbacks=[logging_callback] + extra_callbacks + [InitCallback()] + [checkpoint_callback],
|
||||
logger=logger,
|
||||
plugins=[DDPPlugin(find_unused_parameters=True)], # this is needed in new pytorch-lightning new version
|
||||
val_check_interval=1,
|
||||
num_sanity_val_steps=2,
|
||||
**train_params,
|
||||
@@ -412,6 +409,6 @@ def generic_train(
|
||||
if args.do_train:
|
||||
trainer.fit(model)
|
||||
|
||||
# else:
|
||||
# print("RAG modeling tests with new set functions successfuly executed!")
|
||||
else:
|
||||
print("RAG modeling tests with new set functions successfuly executed!")
|
||||
return trainer
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
faiss-cpu >= 1.7.0
|
||||
datasets >= 1.6.2
|
||||
psutil >= 5.7.0
|
||||
torch >= 1.4.0
|
||||
pytorch-lightning
|
||||
faiss-cpu >= 1.7.2
|
||||
datasets
|
||||
psutil >= 5.9.1
|
||||
torch >= 1.11.0
|
||||
pytorch-lightning == 1.6.4
|
||||
nvidia-ml-py3 == 7.352.0
|
||||
ray >= 1.3.0
|
||||
ray >= 1.13.0
|
||||
@@ -44,11 +44,14 @@ python finetune_rag.py \
|
||||
--num_retrieval_workers 4 \
|
||||
--index_name custom \
|
||||
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
|
||||
--index_gpus 1 \
|
||||
--gpu_order [6,7,8,9,0,1,2,3,5,4] \
|
||||
--index_gpus 2 \
|
||||
--gpu_order [2,3,4,5,6,7,8,9,0,1] \
|
||||
--indexing_freq 5
|
||||
|
||||
|
||||
|
||||
# Stop the Ray cluster.
|
||||
ray stop
|
||||
|
||||
#CUDA_VISIBLE_DEVICES=2,3,4,5,6,7,8,9,0,1 sh ./test_run/test_finetune.sh
|
||||
#Make sure --gpu_order is same.
|
||||
Reference in New Issue
Block a user