Fix rag finetuning + add finetuning test (#8585)
* replace init_ddp_connection for index init * style * add finetune test * add test data * move generate tensors to device * add test on EM metric * style * allow multi process test * keep gloo process group for retrieval * add multi-gpu test * use custom accelerator * clean test finetune * minor * style * style * typo * use python call instead of imported main function * return_dict fix in modeling_rag * use float32 in retrieval * store as float32 as well in the custom knowledge dataset example * style * rename to finetune_rag * style * update readme * rename utils and callbacks to utils_rag and callbacks_rag * fix test * patrick's comments * generate dummy data in the finetune test script * remove dummy data files * style
This commit is contained in:
@@ -384,6 +384,8 @@ def generic_train(
|
|||||||
train_params["distributed_backend"] = "ddp"
|
train_params["distributed_backend"] = "ddp"
|
||||||
|
|
||||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||||
|
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
|
||||||
|
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
|
||||||
|
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ to the retriever to extract relevant context documents. The documents are then p
|
|||||||
Such contextualized inputs are passed to the generator.
|
Such contextualized inputs are passed to the generator.
|
||||||
|
|
||||||
Read more about RAG at https://arxiv.org/abs/2005.11401.
|
Read more about RAG at https://arxiv.org/abs/2005.11401.
|
||||||
# Finetuning
|
|
||||||
|
|
||||||
|
# Finetuning
|
||||||
|
|
||||||
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files:
|
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files:
|
||||||
```bash
|
```bash
|
||||||
@@ -20,10 +20,10 @@ test.source
|
|||||||
test.target
|
test.target
|
||||||
```
|
```
|
||||||
|
|
||||||
A sample finetuning command (run ` ./examples/rag/finetune.py --help` to list all available options):
|
A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/rag/finetune.py \
|
python examples/rag/finetune_rag.py \
|
||||||
--data_dir $DATA_DIR \
|
--data_dir $DATA_DIR \
|
||||||
--output_dir $OUTPUT_DIR \
|
--output_dir $OUTPUT_DIR \
|
||||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||||
@@ -45,7 +45,7 @@ python examples/rag/consolidate_rag_checkpoint.py \
|
|||||||
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
|
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
|
||||||
--dest path/to/checkpoint
|
--dest path/to/checkpoint
|
||||||
```
|
```
|
||||||
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune.py` script.
|
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
|
||||||
|
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
@@ -130,3 +130,29 @@ python examples/rag/eval_rag.py \
|
|||||||
--print_predictions \
|
--print_predictions \
|
||||||
--recalculate \ # adding this parameter will force recalculating predictions even if predictions_path already exists
|
--recalculate \ # adding this parameter will force recalculating predictions even if predictions_path already exists
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# Use your own knowledge source
|
||||||
|
|
||||||
|
By default, RAG uses the English Wikipedia as a knowledge source, known as the 'wiki_dpr' dataset.
|
||||||
|
With `use_own_knowledge_dataset.py` you can build your own knowledge source, *e.g.* for RAG.
|
||||||
|
|
||||||
|
For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
|
||||||
|
```bash
|
||||||
|
python examples/rag/use_own_knowledge_dataset.py \
|
||||||
|
--csv_path path/to/my_csv \
|
||||||
|
--output_dir path/to/my_knowledge_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
|
||||||
|
```bash
|
||||||
|
python examples/rag/finetune_rag.py \
|
||||||
|
--data_dir $DATA_DIR \
|
||||||
|
--output_dir $OUTPUT_DIR \
|
||||||
|
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||||
|
--model_type rag_sequence \
|
||||||
|
--fp16 \
|
||||||
|
--gpus 8 \
|
||||||
|
--index_name custom \
|
||||||
|
--passages_path path/to/data/my_knowledge_dataset \
|
||||||
|
--index_path path/to/my_knowledge_dataset_hnsw_index.faiss
|
||||||
|
```
|
||||||
@@ -8,7 +8,7 @@ import torch
|
|||||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
from pytorch_lightning.utilities import rank_zero_only
|
from pytorch_lightning.utilities import rank_zero_only
|
||||||
|
|
||||||
from utils import save_json
|
from utils_rag import save_json
|
||||||
|
|
||||||
|
|
||||||
def count_trainable_parameters(model):
|
def count_trainable_parameters(model):
|
||||||
@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
|
|||||||
monitor=f"val_{metric}",
|
monitor=f"val_{metric}",
|
||||||
mode="max",
|
mode="max",
|
||||||
save_top_k=3,
|
save_top_k=3,
|
||||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||||
)
|
)
|
||||||
return checkpoint_callback
|
return checkpoint_callback
|
||||||
|
|
||||||
@@ -40,7 +40,6 @@ class RagPyTorchDistributedRetriever(RagRetriever):
|
|||||||
generator_tokenizer=generator_tokenizer,
|
generator_tokenizer=generator_tokenizer,
|
||||||
index=index,
|
index=index,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.process_group = None
|
self.process_group = None
|
||||||
|
|
||||||
def init_retrieval(self, distributed_port: int):
|
def init_retrieval(self, distributed_port: int):
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
|
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import warnings
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
@@ -15,29 +13,31 @@ import numpy as np
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
|
||||||
|
from pytorch_lightning.cluster_environments import TorchElasticEnvironment
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
|
BatchEncoding,
|
||||||
RagConfig,
|
RagConfig,
|
||||||
RagSequenceForGeneration,
|
RagSequenceForGeneration,
|
||||||
RagTokenForGeneration,
|
RagTokenForGeneration,
|
||||||
RagTokenizer,
|
RagTokenizer,
|
||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
get_linear_schedule_with_warmup,
|
|
||||||
)
|
)
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
|
|
||||||
from callbacks import ( # noqa: E402 # isort:skipq
|
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||||
get_checkpoint_callback,
|
get_checkpoint_callback,
|
||||||
get_early_stopping_callback,
|
get_early_stopping_callback,
|
||||||
Seq2SeqLoggingCallback,
|
Seq2SeqLoggingCallback,
|
||||||
)
|
)
|
||||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||||
from utils import ( # noqa: E402 # isort:skip
|
from utils_rag import ( # noqa: E402 # isort:skip
|
||||||
calculate_exact_match,
|
calculate_exact_match,
|
||||||
flatten_list,
|
flatten_list,
|
||||||
get_git_info,
|
get_git_info,
|
||||||
@@ -67,6 +67,30 @@ class AttrDict(dict):
|
|||||||
self.__dict__ = self
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
|
# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
|
||||||
|
# is no longer used, and is moved into DDPAccelerator instead.
|
||||||
|
# We override DDPAccelerator to add our custom logic for initializing the
|
||||||
|
# retriever.
|
||||||
|
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
|
||||||
|
|
||||||
|
|
||||||
|
class CustomAccel(DDPAccelerator):
|
||||||
|
def __init__(self, trainer=None, **kwargs):
|
||||||
|
# Trainer is set later.
|
||||||
|
super().__init__(trainer, **kwargs)
|
||||||
|
|
||||||
|
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
|
||||||
|
logger.info("Custom init_ddp_connection.")
|
||||||
|
module = self.trainer.model
|
||||||
|
if self.cluster_environment is None:
|
||||||
|
self.cluster_environment = TorchElasticEnvironment()
|
||||||
|
self.distributed_port = module.hparams.distributed_port
|
||||||
|
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
||||||
|
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
||||||
|
if module.is_rag_model:
|
||||||
|
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||||
|
|
||||||
|
|
||||||
class GenerativeQAModule(BaseTransformer):
|
class GenerativeQAModule(BaseTransformer):
|
||||||
mode = "generative_qa"
|
mode = "generative_qa"
|
||||||
loss_names = ["loss"]
|
loss_names = ["loss"]
|
||||||
@@ -91,23 +115,24 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
config = config_class.from_pretrained(hparams.model_name_or_path)
|
config = config_class.from_pretrained(hparams.model_name_or_path)
|
||||||
|
|
||||||
# set retriever parameters
|
# set retriever parameters
|
||||||
config.index_name = args.index_name or config.index_name
|
config.index_name = hparams.index_name or config.index_name
|
||||||
config.passages_path = args.passages_path or config.passages_path
|
config.passages_path = hparams.passages_path or config.passages_path
|
||||||
config.index_path = args.index_path or config.index_path
|
config.index_path = hparams.index_path or config.index_path
|
||||||
|
config.use_dummy_dataset = hparams.use_dummy_dataset
|
||||||
|
|
||||||
# set extra_model_params for generator configs and load_model
|
# set extra_model_params for generator configs and load_model
|
||||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
|
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
|
||||||
if self.is_rag_model:
|
if self.is_rag_model:
|
||||||
if args.prefix is not None:
|
if hparams.prefix is not None:
|
||||||
config.generator.prefix = args.prefix
|
config.generator.prefix = hparams.prefix
|
||||||
config.label_smoothing = hparams.label_smoothing
|
config.label_smoothing = hparams.label_smoothing
|
||||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
||||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||||
prefix = config.question_encoder.prefix
|
prefix = config.question_encoder.prefix
|
||||||
else:
|
else:
|
||||||
if args.prefix is not None:
|
if hparams.prefix is not None:
|
||||||
config.prefix = args.prefix
|
config.prefix = hparams.prefix
|
||||||
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
|
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
|
||||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
|
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
|
||||||
prefix = config.prefix
|
prefix = config.prefix
|
||||||
@@ -152,11 +177,9 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
self.num_workers = hparams.num_workers
|
self.num_workers = hparams.num_workers
|
||||||
self.distributed_port = self.hparams.distributed_port
|
self.distributed_port = self.hparams.distributed_port
|
||||||
|
|
||||||
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
|
# For single GPU training, init_ddp_connection is not called.
|
||||||
logger.info("Custom init_ddp_connection.")
|
# So we need to initialize the retrievers here.
|
||||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
if hparams.gpus <= 1:
|
||||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
|
||||||
if self.is_rag_model:
|
|
||||||
self.model.retriever.init_retrieval(self.distributed_port)
|
self.model.retriever.init_retrieval(self.distributed_port)
|
||||||
|
|
||||||
def forward(self, input_ids, **kwargs):
|
def forward(self, input_ids, **kwargs):
|
||||||
@@ -270,6 +293,7 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
|
|
||||||
def _generative_step(self, batch: dict) -> dict:
|
def _generative_step(self, batch: dict) -> dict:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
batch = BatchEncoding(batch).to(device=self.model.device)
|
||||||
generated_ids = self.model.generate(
|
generated_ids = self.model.generate(
|
||||||
batch["input_ids"],
|
batch["input_ids"],
|
||||||
attention_mask=batch["attention_mask"],
|
attention_mask=batch["attention_mask"],
|
||||||
@@ -322,17 +346,6 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
|
|
||||||
def train_dataloader(self) -> DataLoader:
|
def train_dataloader(self) -> DataLoader:
|
||||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||||
t_total = (
|
|
||||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
|
||||||
// self.hparams.accumulate_grad_batches
|
|
||||||
* float(self.hparams.max_epochs)
|
|
||||||
)
|
|
||||||
scheduler = get_linear_schedule_with_warmup(
|
|
||||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
|
||||||
)
|
|
||||||
if max(scheduler.get_last_lr()) > 0:
|
|
||||||
warnings.warn("All learning rates are 0")
|
|
||||||
self.lr_scheduler = scheduler
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
def val_dataloader(self) -> DataLoader:
|
def val_dataloader(self) -> DataLoader:
|
||||||
@@ -429,10 +442,24 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
default=None,
|
default=None,
|
||||||
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_dummy_dataset",
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main(args, model=None) -> GenerativeQAModule:
|
def main(args=None, model=None) -> GenerativeQAModule:
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||||
|
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||||
|
|
||||||
|
args = args or parser.parse_args()
|
||||||
|
|
||||||
Path(args.output_dir).mkdir(exist_ok=True)
|
Path(args.output_dir).mkdir(exist_ok=True)
|
||||||
if model is None:
|
if model is None:
|
||||||
model: GenerativeQAModule = GenerativeQAModule(args)
|
model: GenerativeQAModule = GenerativeQAModule(args)
|
||||||
@@ -461,6 +488,7 @@ def main(args, model=None) -> GenerativeQAModule:
|
|||||||
if args.early_stopping_patience >= 0
|
if args.early_stopping_patience >= 0
|
||||||
else False
|
else False
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer: pl.Trainer = generic_train(
|
trainer: pl.Trainer = generic_train(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
@@ -468,31 +496,17 @@ def main(args, model=None) -> GenerativeQAModule:
|
|||||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||||
early_stopping_callback=es_callback,
|
early_stopping_callback=es_callback,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
accelerator=CustomAccel() if args.gpus > 1 else None,
|
||||||
)
|
)
|
||||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||||
|
|
||||||
if not args.do_predict:
|
if not args.do_predict:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
model.hparams.test_checkpoint = ""
|
|
||||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
|
|
||||||
if checkpoints:
|
|
||||||
model.hparams.test_checkpoint = checkpoints[-1]
|
|
||||||
trainer.resume_from_checkpoint = checkpoints[-1] # best checkpoint
|
|
||||||
trainer.logger.log_hyperparams(model.hparams)
|
|
||||||
|
|
||||||
# test() without a model tests using the best checkpoint automatically
|
# test() without a model tests using the best checkpoint automatically
|
||||||
trainer.test()
|
trainer.test()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
main()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
|
||||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
|
||||||
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
main(args)
|
|
||||||
@@ -4,7 +4,7 @@ export PYTHONPATH="../":"${PYTHONPATH}"
|
|||||||
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
||||||
# run ./examples/rag/finetune.sh --help to see all the possible options
|
# run ./examples/rag/finetune.sh --help to see all the possible options
|
||||||
|
|
||||||
python examples/rag/finetune.py \
|
python examples/rag/finetune_rag.py \
|
||||||
--data_dir $DATA_DIR \
|
--data_dir $DATA_DIR \
|
||||||
--output_dir $OUTPUT_DIR \
|
--output_dir $OUTPUT_DIR \
|
||||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||||
96
examples/rag/test_finetune_rag.py
Normal file
96
examples/rag/test_finetune_rag.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import finetune_rag
|
||||||
|
from transformers.file_utils import is_apex_available
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
TestCasePlus,
|
||||||
|
execute_subprocess_async,
|
||||||
|
require_torch_gpu,
|
||||||
|
require_torch_multi_gpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class RagFinetuneExampleTests(TestCasePlus):
|
||||||
|
def _create_dummy_data(self, data_dir):
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
contents = {"source": "What is love ?", "target": "life"}
|
||||||
|
n_lines = {"train": 12, "val": 2, "test": 2}
|
||||||
|
for split in ["train", "test", "val"]:
|
||||||
|
for field in ["source", "target"]:
|
||||||
|
content = "\n".join([contents[field]] * n_lines[split])
|
||||||
|
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
def _run_finetune(self, gpus: int):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
output_dir = os.path.join(tmp_dir, "output")
|
||||||
|
data_dir = os.path.join(tmp_dir, "data")
|
||||||
|
self._create_dummy_data(data_dir=data_dir)
|
||||||
|
|
||||||
|
testargs = f"""
|
||||||
|
--data_dir {data_dir} \
|
||||||
|
--output_dir {output_dir} \
|
||||||
|
--model_name_or_path facebook/rag-sequence-base \
|
||||||
|
--model_type rag_sequence \
|
||||||
|
--do_train \
|
||||||
|
--do_predict \
|
||||||
|
--n_val -1 \
|
||||||
|
--val_check_interval 1.0 \
|
||||||
|
--train_batch_size 2 \
|
||||||
|
--eval_batch_size 1 \
|
||||||
|
--max_source_length 25 \
|
||||||
|
--max_target_length 25 \
|
||||||
|
--val_max_target_length 25 \
|
||||||
|
--test_max_target_length 25 \
|
||||||
|
--label_smoothing 0.1 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--attention_dropout 0.1 \
|
||||||
|
--weight_decay 0.001 \
|
||||||
|
--adam_epsilon 1e-08 \
|
||||||
|
--max_grad_norm 0.1 \
|
||||||
|
--lr_scheduler polynomial \
|
||||||
|
--learning_rate 3e-04 \
|
||||||
|
--num_train_epochs 1 \
|
||||||
|
--warmup_steps 4 \
|
||||||
|
--gradient_accumulation_steps 1 \
|
||||||
|
--distributed-port 8787 \
|
||||||
|
--use_dummy_dataset 1 \
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
if gpus > 0:
|
||||||
|
testargs.append(f"--gpus={gpus}")
|
||||||
|
if is_apex_available():
|
||||||
|
testargs.append("--fp16")
|
||||||
|
else:
|
||||||
|
testargs.append("--gpus=0")
|
||||||
|
testargs.append("--distributed_backend=ddp_cpu")
|
||||||
|
testargs.append("--num_processes=2")
|
||||||
|
|
||||||
|
cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
|
||||||
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
|
|
||||||
|
metrics_save_path = os.path.join(output_dir, "metrics.json")
|
||||||
|
with open(metrics_save_path) as f:
|
||||||
|
result = json.load(f)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_finetune_gpu(self):
|
||||||
|
result = self._run_finetune(gpus=1)
|
||||||
|
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_finetune_multigpu(self):
|
||||||
|
result = self._run_finetune(gpus=2)
|
||||||
|
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||||
@@ -7,7 +7,7 @@ from tempfile import TemporaryDirectory
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import Features, Sequence, Value, load_dataset
|
||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -82,10 +82,14 @@ def main(
|
|||||||
# And compute the embeddings
|
# And compute the embeddings
|
||||||
ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
|
ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
|
||||||
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
|
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
|
||||||
|
new_features = Features(
|
||||||
|
{"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
|
||||||
|
) # optional, save as float32 instead of float64 to save space
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
|
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
|
||||||
batched=True,
|
batched=True,
|
||||||
batch_size=processing_args.batch_size,
|
batch_size=processing_args.batch_size,
|
||||||
|
features=new_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
# And finally save your dataset
|
# And finally save your dataset
|
||||||
|
|||||||
@@ -556,7 +556,9 @@ class RagModel(RagPreTrainedModel):
|
|||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
|
|
||||||
if has_to_retrieve:
|
if has_to_retrieve:
|
||||||
question_enc_outputs = self.question_encoder(input_ids, attention_mask=attention_mask)
|
question_enc_outputs = self.question_encoder(
|
||||||
|
input_ids, attention_mask=attention_mask, return_dict=True
|
||||||
|
)
|
||||||
question_encoder_last_hidden_state = question_enc_outputs[0] # hidden states of question encoder
|
question_encoder_last_hidden_state = question_enc_outputs[0] # hidden states of question encoder
|
||||||
|
|
||||||
retriever_outputs = self.retriever(
|
retriever_outputs = self.retriever(
|
||||||
@@ -616,6 +618,7 @@ class RagModel(RagPreTrainedModel):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
return_dict=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not has_to_retrieve:
|
if not has_to_retrieve:
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ class HFIndexBase(Index):
|
|||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self._index_initialized = index_initialized
|
self._index_initialized = index_initialized
|
||||||
self._check_dataset_format(with_index=index_initialized)
|
self._check_dataset_format(with_index=index_initialized)
|
||||||
dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
|
dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
|
||||||
|
|
||||||
def _check_dataset_format(self, with_index: bool):
|
def _check_dataset_format(self, with_index: bool):
|
||||||
if not isinstance(self.dataset, Dataset):
|
if not isinstance(self.dataset, Dataset):
|
||||||
|
|||||||
Reference in New Issue
Block a user