RAG (#6813)
* added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * Formatting / renaming prior to actual work * First commit * improve comments * Retrieval evaluation scripts * refactor to include modeling outputs + MPI retriever * Fix rag-token model + refactor * Various fixes + finetuning logic * use_bos fix * Retrieval refactor * Finetuning refactoring and cleanup * Add documentation and cleanup * Remove set_up_rag_env.sh file * Fix retrieval wit HF index * Fix import errors * Fix quality errors * Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867 * fix quality * Fix RAG Sequence generation * minor cleanup plus initial tests * fix test * fix tests 2 * Comments fix * post-merge fixes * Improve readme + post-rebase refactor * Extra dependencied for tests * Fix tests * Fix tests 2 * Refactor test requirements * Fix tests 3 * Post-rebase refactor * rename nlp->datasets * RAG integration tests * add tokenizer to slow integration test and allow retriever to run on cpu * add tests; fix position ids warning * change structure * change structure * add from encoder generator * save working solution * make all integration tests pass * add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained * don't save paths * delete unnecessary imports * pass config to AutoTokenizer.from_pretrained for Rag tokenizers * init wiki_dpr only once * hardcode legacy index and passages paths (todo: add the right urls) * finalize config * finalize retriver api and config api * LegacyIndex index download refactor * add dpr to autotokenizer * make from pretrained more flexible * fix ragfortokengeneration * small name changes in tokenizer * add labels to models * change default index name * add retrieval tests * finish token generate * align test with previous version and make all tests pass * add tests * finalize tests * implement thoms suggestions * add first version of test * make first tests work * make retriever platform agnostic * naming * style * add legacy index URL * docstrings + simple retrieval test for distributed * clean model api * add doc_ids to retriever's outputs * fix retrieval tests * finish model outputs * finalize model api * fix generate problem for rag * fix generate for other modles * fix some tests * save intermediate * set generate to default * big refactor generate * delete rag_api * correct pip faiss install * fix auto tokenization test * fix faiss install * fix test * move the distributed logic to examples * model page * docs * finish tests * fix dependencies * fix import in __init__ * Refactor eval_rag and finetune scripts * start docstring * add psutil to test * fix tf test * move require torch to top * fix retrieval test * align naming * finish automodel * fix repo consistency * test ragtokenizer save/load * add rag model output docs * fix ragtokenizer save/load from pretrained * fix tokenizer dir * remove torch in retrieval * fix docs * fixe finetune scripts * finish model docs * finish docs * remove auto model for now * add require torch * remove solved todos * integrate sylvains suggestions * sams comments * correct mistake on purpose * improve README * Add generation test cases * fix rag token * clean token generate * fix test * add note to test * fix attention mask * add t5 test for rag * Fix handling prefix in finetune.py * don't overwrite index_name Co-authored-by: Patrick Lewis <plewis@fb.com> Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair> Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair> Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair> Co-authored-by: Your Name <you@example.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
This commit is contained in:
@@ -366,6 +366,8 @@ def generic_train(
|
||||
if args.gpus > 1:
|
||||
train_params["distributed_backend"] = "ddp"
|
||||
|
||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
weights_summary=None,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import datasets
|
||||
import faiss
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
import faiss
|
||||
import transformers
|
||||
from eli5_utils import (
|
||||
embed_questions_for_retrieval,
|
||||
|
||||
@@ -5,7 +5,6 @@ from random import choice, randint
|
||||
from time import time
|
||||
|
||||
import datasets # noqa: F401
|
||||
import faiss # noqa: F401
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -15,6 +14,7 @@ from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
import faiss # noqa: F401
|
||||
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
|
||||
88
examples/rag/README.md
Normal file
88
examples/rag/README.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# Intro
|
||||
RAG is a seq2seq model which encapsulates two core components: a question encoder and a generator.
|
||||
During a forward pass, we encode the input with the question encoder and pass it
|
||||
to the retriever to extract relevant context documents. The documents are then prepended to the input.
|
||||
Such contextualized inputs is passed to the generator.
|
||||
|
||||
The question encoder can be any `autoencoding` model, preferably :obj:`~transformers.DPRQuestionEncoder`, and the generator can be any `seq2seq` model, preferably :obj:`~transformers.BartForConditionalGeneration`.
|
||||
|
||||
The model can be initialized with a :obj:`~transformers.RagRetriever` for end-to-end generation or used in combination with the outputs of a retriever in multiple steps - see examples for more details.
|
||||
The model is compatible any `autoencoding` model as the ``question_encoder`` and any `seq2seq` model with language model head as the ``generator``.
|
||||
The model has been tested with :class:`~transformers.DPRQuestionEncoder` as the ``question_encoder`` and :class:`~transformers.BartForConditionalGeneration` or :class:`~transformers.T5ForConditionalGeneration` as the ``generator``.
|
||||
|
||||
RAG models were released with the paper `Retrieval-Augmented Generation for
|
||||
Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`_ by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.
|
||||
|
||||
|
||||
# Finetuning
|
||||
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq).
|
||||
Follow instructions there regarding data preprocessing. A sample finetuning command:
|
||||
|
||||
```
|
||||
python examples/rag/finetune.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8
|
||||
```
|
||||
|
||||
|
||||
# Evaluation
|
||||
Apart from the parameters specifying the model to evaluate and some extra parameters, the evaluation script expects paths to two files:
|
||||
- `evaluation_set` - a path to a file specifying the evaluation dataset, a single datapoint per line, e.g.
|
||||
```who is the owner of reading football club```
|
||||
- `gold_data_path` - a path to a file contaning ground truth answers for datapoints from the `evaluation_set`.
|
||||
|
||||
We expect the following formats of the gold data file:
|
||||
|
||||
- for e2e evaluation, we support two formats of the gold file:
|
||||
- `qa` - where a single line in the following format: input [tab] output_list, e.g.:
|
||||
```
|
||||
who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Yongge Dai']
|
||||
```
|
||||
- `ans` - where a single line of the gold file contains the expected output string, e.g.:
|
||||
```
|
||||
Xiu Li Dai
|
||||
```
|
||||
|
||||
- for retrieval evaluation, we expect a tab-separated list of Wikipedia page titles constituting positive contexts for a given query, e.g. given a question `who sings does he love me with reba`, a line with ground truth retrieval data could look as follows:
|
||||
```
|
||||
Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greatest Hits Volume Two (Reba McEntire album) Shoot for the Moon (album)
|
||||
```
|
||||
|
||||
## Retrieval evaluation
|
||||
|
||||
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).
|
||||
|
||||
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
|
||||
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
|
||||
```
|
||||
python examples/rag/parse_dpr_relevance_data.py --src_path path/to/unziped/biencoder-nq-dev.json --evaluation_set path/to/output/biencoder-nq-dev.questions --gold_data_path path/to/output/biencoder-nq-dev.pages
|
||||
```
|
||||
3. Run evaluation:
|
||||
```
|
||||
python examples/rag/eval_rag.py \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \ # model name or path of the model we're evaluating
|
||||
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
|
||||
--evaluation_set path/to/output/biencoder-nq-dev.questions \ # an input dataset for evaluation
|
||||
--gold_data_path path/to/output/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set
|
||||
--predictions_path path/to/retrieval_preds.tsv \ # name of file in which predictions will be stored
|
||||
--eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation
|
||||
--recalculate # if predictions_filename already exists, and this option is set - we regenerate the answers, otherwise we reuse the predicsion file to calculate metrics.
|
||||
```
|
||||
|
||||
|
||||
## End-to-end evaluation
|
||||
```
|
||||
python examples/rag/eval_rag.py \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--evaluation_set path/to/test.source \
|
||||
--gold_data_path path/to/gold_data \
|
||||
--predictions_path path/to/e2e_preds.txt \
|
||||
--eval_mode e2e \ # indicates whether we're performing retrieval evaluation or e2e evaluation (default)
|
||||
--n_docs 5 \ # You can experiment with retrieving different number of documents at evaluation time
|
||||
--print_predictions
|
||||
```
|
||||
0
examples/rag/__init__.py
Normal file
0
examples/rag/__init__.py
Normal file
30
examples/rag/callbacks.py
Normal file
30
examples/rag/callbacks.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_checkpoint_callback(output_dir, metric):
|
||||
"""Saves the best model by validation EM score."""
|
||||
if metric == "rouge2":
|
||||
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
||||
elif metric == "bleu":
|
||||
exp = "{val_avg_bleu:.4f}-{step_count}"
|
||||
elif metric == "em":
|
||||
exp = "{val_avg_em:.4f}-{step_count}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=os.path.join(output_dir, exp),
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
save_top_k=3,
|
||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
)
|
||||
return checkpoint_callback
|
||||
135
examples/rag/distributed_retriever.py
Normal file
135
examples/rag/distributed_retriever.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from transformers import RagRetriever
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
"""
|
||||
A distributed retriever built on top of the ``torch.distributed`` communication package. During training all workers
|
||||
initalize their own instance of the retriever, however, only the main worker loads the index into memory. The index is stored
|
||||
in cpu memory. The index will also work well in a non-distributed setup.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.RagConfig`):
|
||||
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
|
||||
question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer that was used to tokenize the question.
|
||||
It is used to decode the question and then use the generator_tokenizer.
|
||||
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
"""
|
||||
|
||||
_init_retrieval = False
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
|
||||
super().__init__(
|
||||
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
|
||||
)
|
||||
|
||||
self.process_group = None
|
||||
|
||||
def init_retrieval(self, distributed_port: int):
|
||||
"""
|
||||
Retriever initalization function, needs to be called from the training process. The function sets some common parameters
|
||||
and environment variables. On top of that, (only) the main process in the process group loads the index into memory.
|
||||
|
||||
Args:
|
||||
distributed_port (:obj:`int`):
|
||||
The port on which the main communication of the training run is carried out. We set the port for retrieval-related
|
||||
communication as ``distributed_port + 1``.
|
||||
"""
|
||||
|
||||
logger.info("initializing retrieval")
|
||||
|
||||
# initializing a separate process group for retrievel as the default
|
||||
# nccl backend doesn't support gather/scatter operations while gloo
|
||||
# is too slow to replace nccl for the core gpu communication
|
||||
if dist.is_initialized():
|
||||
logger.info("dist initialized")
|
||||
# needs to be set manually
|
||||
os.environ["GLOO_SOCKET_IFNAME"] = self._infer_socket_ifname()
|
||||
# avoid clash with the NCCL port
|
||||
os.environ["MASTER_PORT"] = str(distributed_port + 1)
|
||||
self.process_group = dist.new_group(ranks=None, backend="gloo")
|
||||
|
||||
# initialize retriever only on the main worker
|
||||
if not dist.is_initialized() or self._is_main():
|
||||
logger.info("dist not initialized / main")
|
||||
self.index.init_index()
|
||||
|
||||
# all processes wait untill the retriever is initialized by the main process
|
||||
if dist.is_initialized():
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
def _is_main(self):
|
||||
return dist.get_rank(group=self.process_group) == 0
|
||||
|
||||
def _scattered(self, scatter_list, target_shape, target_type=torch.float32):
|
||||
target_tensor = torch.empty(target_shape, dtype=target_type)
|
||||
dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group)
|
||||
return target_tensor
|
||||
|
||||
def _infer_socket_ifname(self):
|
||||
addrs = psutil.net_if_addrs()
|
||||
# a hacky way to deal with varying network interface names
|
||||
ifname = next((addr for addr in addrs if addr.startswith("e")), None)
|
||||
return ifname
|
||||
|
||||
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:
|
||||
"""
|
||||
Retrieves documents for specified ``question_hidden_states``. The main process, which has the access to the index stored in memory, gathers queries
|
||||
from all the processes in the main training process group, performs the retrieval and scatters back the results.
|
||||
|
||||
Args:
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
|
||||
A batch of query vectors to retrieve with.
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Ouput:
|
||||
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
|
||||
The retrieval embeddings of the retrieved docs per query.
|
||||
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
|
||||
The ids of the documents in the index
|
||||
doc_dicts (:obj:`List[dict]`):
|
||||
The retrieved_doc_embeds examples per query.
|
||||
"""
|
||||
|
||||
# single GPU training
|
||||
if not dist.is_initialized():
|
||||
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
|
||||
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
|
||||
|
||||
# distributed training
|
||||
world_size = dist.get_world_size(group=self.process_group)
|
||||
|
||||
# gather logic
|
||||
gather_list = None
|
||||
if self._is_main():
|
||||
gather_list = [torch.empty(question_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
|
||||
dist.gather(torch.tensor(question_hidden_states), dst=0, gather_list=gather_list, group=self.process_group)
|
||||
|
||||
# scatter logic
|
||||
n_queries = question_hidden_states.shape[0]
|
||||
scatter_ids = []
|
||||
scatter_vectors = []
|
||||
if self._is_main():
|
||||
assert len(gather_list) == world_size
|
||||
ids, vectors = self._main_retrieve(torch.cat(gather_list).numpy(), n_docs)
|
||||
ids, vectors = torch.tensor(ids), torch.tensor(vectors)
|
||||
scatter_ids = self._chunk_tensor(ids, n_queries)
|
||||
scatter_vectors = self._chunk_tensor(vectors, n_queries)
|
||||
doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64)
|
||||
retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]])
|
||||
|
||||
return retrieved_doc_embeds.numpy(), doc_ids.numpy(), self.index.get_doc_dicts(doc_ids)
|
||||
310
examples/rag/eval_rag.py
Normal file
310
examples/rag/eval_rag.py
Normal file
@@ -0,0 +1,310 @@
|
||||
""" Evaluation script for RAG models."""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
|
||||
from examples.rag.utils import exact_match_score, f1_score # noqa: E402 # isort:skip
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
def infer_model_type(model_name_or_path):
|
||||
if "token" in model_name_or_path:
|
||||
return "rag_token"
|
||||
if "sequence" in model_name_or_path:
|
||||
return "rag_sequence"
|
||||
if "bart" in model_name_or_path:
|
||||
return "bart"
|
||||
return None
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
return max(metric_fn(prediction, gt) for gt in ground_truths)
|
||||
|
||||
|
||||
def get_scores(args, preds_path, gold_data_path):
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
answers = []
|
||||
|
||||
if args.gold_data_mode == "qa":
|
||||
data = pd.read_csv(gold_data_path, sep="\t", header=None)
|
||||
for answer_list in data[1]:
|
||||
ground_truths = ast.literal_eval(answer_list)
|
||||
answers.append(ground_truths)
|
||||
else:
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
answers = [[reference] for reference in references]
|
||||
|
||||
f1 = em = total = 0
|
||||
for prediction, ground_truths in zip(hypos, answers):
|
||||
total += 1
|
||||
em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
|
||||
|
||||
em = 100.0 * em / total
|
||||
f1 = 100.0 * f1 / total
|
||||
|
||||
logger.info(f"F1: {f1:.2f}")
|
||||
logger.info(f"EM: {em:.2f}")
|
||||
|
||||
|
||||
def get_precision_at_k(args, preds_path, gold_data_path):
|
||||
k = args.k
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
|
||||
em = total = 0
|
||||
for hypo, reference in zip(hypos, references):
|
||||
hypo_provenance = set(hypo.split("\t")[:k])
|
||||
ref_provenance = set(reference.split("\t")[1 : (k + 1)])
|
||||
total += 1
|
||||
em += len(hypo_provenance & ref_provenance) / k
|
||||
|
||||
em = 100.0 * em / total
|
||||
logger.info(f"Precision@{k}: {em: .2f}")
|
||||
|
||||
|
||||
def evaluate_batch_retrieval(args, rag_model, questions):
|
||||
def strip_title(title):
|
||||
if title.startswith('"'):
|
||||
title = title[1:]
|
||||
if title.endswith('"'):
|
||||
title = title[:-1]
|
||||
return title
|
||||
|
||||
retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)["input_ids"].to(args.device)
|
||||
|
||||
question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids, return_dict=True)
|
||||
question_enc_pool_output = question_enc_outputs.pooler_output
|
||||
|
||||
result = rag_model.retriever(
|
||||
retriever_input_ids,
|
||||
question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
|
||||
prefix=rag_model.rag.generator.config.prefix,
|
||||
n_docs=rag_model.config.n_docs,
|
||||
return_tensors="pt",
|
||||
)
|
||||
all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
|
||||
provenance_strings = []
|
||||
for docs in all_docs:
|
||||
provenance = [strip_title(title) for title in docs["title"]]
|
||||
provenance_strings.append("\t".join(provenance))
|
||||
return provenance_strings
|
||||
|
||||
|
||||
def evaluate_batch_e2e(args, rag_model, questions):
|
||||
with torch.no_grad():
|
||||
input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions, return_tensors="pt", padding=True, truncation=True
|
||||
)["input_ids"].to(args.device)
|
||||
outputs = rag_model.generate( # rag_model overwrites generate
|
||||
input_ids,
|
||||
num_beams=args.num_beams,
|
||||
min_length=args.min_length,
|
||||
max_length=args.max_length,
|
||||
early_stopping=False,
|
||||
num_return_sequences=1,
|
||||
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, dont allow it to generate more than one
|
||||
clean_up_tokenization=True,
|
||||
print_docs=args.print_docs,
|
||||
)
|
||||
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
if args.print_predictions:
|
||||
for q, a in zip(questions, answers):
|
||||
logger.info("Q: {} - A: {}".format(q, a))
|
||||
|
||||
return answers
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart"],
|
||||
type=str,
|
||||
help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
default=None,
|
||||
choices=["hf", "legacy"],
|
||||
type=str,
|
||||
help="RAG model retriever type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the retrieval index",
|
||||
)
|
||||
parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained checkpoints or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_mode",
|
||||
choices=["e2e", "retrieval"],
|
||||
default="e2e",
|
||||
type=str,
|
||||
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calulates precision@k.",
|
||||
)
|
||||
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
|
||||
parser.add_argument(
|
||||
"--evaluation_set",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a file containing evaluation samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a tab-separated file with gold samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_mode",
|
||||
default="qa",
|
||||
type=str,
|
||||
choices=["qa", "ans"],
|
||||
help="Format of the gold data file"
|
||||
"qa - a single line in the following format: question [tab] answer_list"
|
||||
"ans - a single line of the gold file contains the expected answer string",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predictions_path",
|
||||
type=str,
|
||||
default="predictions.txt",
|
||||
help="Name of the predictions file, to be stored in the checkpoints directry",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recalculate",
|
||||
help="Recalculate predictions even if the prediction file exists",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Number of beams to be used when generating answers",
|
||||
)
|
||||
parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers")
|
||||
parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers")
|
||||
|
||||
parser.add_argument(
|
||||
"--print_predictions",
|
||||
action="store_true",
|
||||
help="If True, prints predictions while evaluating.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_docs",
|
||||
action="store_true",
|
||||
help="If True, prints docs retried while generating.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
model_kwargs = {}
|
||||
if args.model_type is None:
|
||||
args.model_type = infer_model_type(args.model_name_or_path)
|
||||
assert args.model_type is not None
|
||||
if args.model_type.startswith("rag"):
|
||||
model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
|
||||
model_kwargs["n_docs"] = args.n_docs
|
||||
if args.index_name is not None:
|
||||
model_kwargs["index_name"] = args.index_name
|
||||
if args.index_path is not None:
|
||||
model_kwargs["index_path"] = args.index_path
|
||||
else:
|
||||
model_class = BartForConditionalGeneration
|
||||
|
||||
checkpoints = (
|
||||
[f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
|
||||
if args.eval_all_checkpoints
|
||||
else [args.model_name_or_path]
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
|
||||
evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
if os.path.exists(args.predictions_path) and (not args.recalculate):
|
||||
logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
continue
|
||||
|
||||
logger.info("***** Running evaluation for {} *****".format(checkpoint))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
logger.info(" Predictions will be stored under {}".format(args.predictions_path))
|
||||
|
||||
if args.model_type.startswith("rag"):
|
||||
retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
|
||||
model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
|
||||
model.retriever.init_retrieval()
|
||||
else:
|
||||
model = model_class.from_pretrained(checkpoint, **model_kwargs)
|
||||
model.to(args.device)
|
||||
|
||||
with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
|
||||
questions = []
|
||||
for line in tqdm(eval_file):
|
||||
questions.append(line.strip())
|
||||
if len(questions) == args.eval_batch_size:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers) + "\n")
|
||||
preds_file.flush()
|
||||
questions = []
|
||||
if len(questions) > 0:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers))
|
||||
preds_file.flush()
|
||||
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
474
examples/rag/finetune.py
Normal file
474
examples/rag/finetune.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
BartForConditionalGeneration,
|
||||
RagConfig,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenForGeneration,
|
||||
RagTokenizer,
|
||||
T5ForConditionalGeneration,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
|
||||
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip
|
||||
from examples.rag.callbacks import get_checkpoint_callback # noqa: E402 # isort:skip
|
||||
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
from examples.rag.utils import ( # noqa: E402 # isort:skip
|
||||
Seq2SeqDataset,
|
||||
calculate_exact_match,
|
||||
is_rag_model,
|
||||
set_extra_model_params,
|
||||
)
|
||||
from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback # noqa: E402 # isort:skip
|
||||
from examples.seq2seq.utils import ( # noqa: E402 # isort:skip
|
||||
flatten_list,
|
||||
get_git_info,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
class GenerativeQAModule(BaseTransformer):
|
||||
mode = "generative_qa"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["em"]
|
||||
val_metric = "em"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
# when loading from a pytorch lightning checkpoint, hparams are passed as dict
|
||||
if isinstance(hparams, dict):
|
||||
hparams = AttrDict(hparams)
|
||||
if hparams.model_type == "rag_sequence":
|
||||
self.model_class = RagSequenceForGeneration
|
||||
elif hparams.model_type == "rag_token":
|
||||
self.model_class = RagTokenForGeneration
|
||||
elif hparams.model_type == "bart":
|
||||
self.model_class = BartForConditionalGeneration
|
||||
else:
|
||||
self.model_class = T5ForConditionalGeneration
|
||||
self.is_rag_model = is_rag_model(hparams.model_type)
|
||||
|
||||
config_class = RagConfig if self.is_rag_model else AutoConfig
|
||||
config = config_class.from_pretrained(hparams.model_name_or_path)
|
||||
|
||||
# set extra_model_params for generator configs and load_model
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
|
||||
if self.is_rag_model:
|
||||
if args.prefix is not None:
|
||||
config.generator.prefix = args.prefix
|
||||
config.label_smoothing = hparams.label_smoothing
|
||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||
prefix = config.question_encoder.prefix
|
||||
else:
|
||||
if args.prefix is not None:
|
||||
config.prefix = args.prefix
|
||||
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
prefix = config.prefix
|
||||
|
||||
tokenizer = (
|
||||
RagTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
if self.is_rag_model
|
||||
else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
)
|
||||
|
||||
super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)
|
||||
|
||||
save_git_info(self.hparams.output_dir)
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
|
||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||
pickle_save(self.hparams, self.hparams_save_path)
|
||||
self.step_count = 0
|
||||
self.metrics = defaultdict(list)
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
max_source_length=self.hparams.max_source_length,
|
||||
prefix=prefix or "",
|
||||
)
|
||||
n_observations_per_split = {
|
||||
"train": self.hparams.n_train,
|
||||
"val": self.hparams.n_val,
|
||||
"test": self.hparams.n_test,
|
||||
}
|
||||
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
|
||||
|
||||
self.target_lens = {
|
||||
"train": self.hparams.max_target_length,
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.distributed_port = self.hparams.distributed_port
|
||||
|
||||
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
|
||||
logger.info("Custom init_ddp_connection.")
|
||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
||||
if self.is_rag_model:
|
||||
self.model.retriever.init_retrieval(self.distributed_port)
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
|
||||
def ids_to_clean_text(self, generated_ids: List[int]):
|
||||
gen_text = self.tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
return lmap(str.strip, gen_text)
|
||||
|
||||
def _step(self, batch: dict) -> Tuple:
|
||||
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
|
||||
rag_kwargs = {}
|
||||
if isinstance(self.model, T5ForConditionalGeneration):
|
||||
decoder_input_ids = self.model._shift_right(target_ids)
|
||||
lm_labels = target_ids
|
||||
elif isinstance(self.model, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous()
|
||||
lm_labels = target_ids[:, 1:].clone()
|
||||
else:
|
||||
assert self.is_rag_model
|
||||
generator = self.model.rag.generator
|
||||
if isinstance(generator, T5ForConditionalGeneration):
|
||||
decoder_start_token_id = generator.config.decoder_start_token_id
|
||||
decoder_input_ids = (
|
||||
torch.cat(
|
||||
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
|
||||
dim=1,
|
||||
)
|
||||
if target_ids.shape[0] < self.target_lens["train"]
|
||||
else generator._shift_right(target_ids)
|
||||
)
|
||||
elif isinstance(generator, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids
|
||||
lm_labels = decoder_input_ids
|
||||
rag_kwargs["reduce_loss"] = True
|
||||
|
||||
assert decoder_input_ids is not None
|
||||
|
||||
outputs = self(
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
use_cache=False,
|
||||
labels=lm_labels,
|
||||
return_dict=True,
|
||||
**rag_kwargs,
|
||||
)
|
||||
|
||||
loss = outputs["loss"]
|
||||
return (loss,)
|
||||
|
||||
@property
|
||||
def pad(self) -> int:
|
||||
raise NotImplementedError("pad not implemented")
|
||||
|
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
loss_tensors = self._step(batch)
|
||||
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
# tokens per batch
|
||||
tgt_pad_token_id = (
|
||||
self.tokenizer.generator.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
src_pad_token_id = (
|
||||
self.tokenizer.question_encoder.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
logs["tpb"] = (
|
||||
batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
|
||||
)
|
||||
|
||||
return {"loss": loss_tensors[0], "log": logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
return self._generative_step(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
gen_metrics = {
|
||||
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
|
||||
}
|
||||
metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
|
||||
gen_metrics.update({k: v.item() for k, v in losses.items()})
|
||||
|
||||
# fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
|
||||
metrics_tensor = metrics_tensor / dist.get_world_size()
|
||||
gen_metrics.update({self.val_metric: metrics_tensor.item()})
|
||||
|
||||
losses.update(gen_metrics)
|
||||
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
metrics["step_count"] = self.step_count
|
||||
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metrics_tensor}
|
||||
|
||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||
self.metrics[type_path].append(latest_metrics)
|
||||
save_json(self.metrics, self.metrics_save_path)
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_exact_match(preds, target)
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
start_time = time.time()
|
||||
generated_ids = self.model.generate(
|
||||
batch["input_ids"],
|
||||
do_deduplication=False, # rag specific parameter
|
||||
use_cache=True,
|
||||
min_length=1,
|
||||
max_length=self.target_lens["val"],
|
||||
)
|
||||
|
||||
gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
|
||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
gen_metrics: Dict = self.calc_generative_metrics(preds, target)
|
||||
|
||||
summ_len = np.mean(lmap(len, generated_ids))
|
||||
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
|
||||
return base_metrics
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_epoch_end(outputs, prefix="test")
|
||||
|
||||
def get_dataset(self, type_path) -> Seq2SeqDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = Seq2SeqDataset(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
max_target_length=max_target_length,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
sampler = None
|
||||
if self.hparams.sortish_sampler and type_path == "train":
|
||||
assert self.hparams.gpus <= 1 # TODO: assert earlier
|
||||
sampler = dataset.make_sortish_sampler(batch_size)
|
||||
shuffle = False
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=shuffle,
|
||||
num_workers=self.num_workers,
|
||||
sampler=sampler,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
t_total = (
|
||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
||||
// self.hparams.accumulate_grad_batches
|
||||
* float(self.hparams.max_epochs)
|
||||
)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if max(scheduler.get_last_lr()) > 0:
|
||||
warnings.warn("All learning rates are 0")
|
||||
self.lr_scheduler = scheduler
|
||||
return dataloader
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||
add_generic_args(parser, root_dir)
|
||||
parser.add_argument(
|
||||
"--max_source_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prefix added at the beginning of each text, typically used with T5-based models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early_stopping_patience",
|
||||
type=int,
|
||||
default=-1,
|
||||
required=False,
|
||||
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart", "t5"],
|
||||
type=str,
|
||||
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(args, model=None) -> GenerativeQAModule:
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if model is None:
|
||||
model: GenerativeQAModule = GenerativeQAModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger_name == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
|
||||
es_callback = (
|
||||
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||
if args.early_stopping_patience >= 0
|
||||
else False
|
||||
)
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
|
||||
if not args.do_predict:
|
||||
return model
|
||||
|
||||
model.hparams.test_checkpoint = ""
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
|
||||
if checkpoints:
|
||||
model.hparams.test_checkpoint = checkpoints[-1]
|
||||
trainer.resume_from_checkpoint = checkpoints[-1] # best checkpoint
|
||||
trainer.logger.log_hyperparams(model.hparams)
|
||||
|
||||
# test() without a model tests using the best checkpoint automatically
|
||||
trainer.test()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
34
examples/rag/finetune.sh
Executable file
34
examples/rag/finetune.sh
Executable file
@@ -0,0 +1,34 @@
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./examples/rag/finetune.sh --help to see all the possible options
|
||||
|
||||
python examples/rag/finetune.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODLE_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--val_check_interval 0.25 \
|
||||
--train_batch_size 8 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 128 \
|
||||
--max_target_length 25 \
|
||||
--val_max_target_length 25 \
|
||||
--test_max_target_length 25 \
|
||||
--label_smoothing 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
--weight_decay 0.001 \
|
||||
--adam_epsilon 1e-08 \
|
||||
--max_grad_norm 0.1 \
|
||||
--lr_scheduler polynomial \
|
||||
--learning_rate 3e-05 \
|
||||
--num_train_epochs 100 \
|
||||
--warmup_steps 500 \
|
||||
--gradient_accumulation_steps 1
|
||||
47
examples/rag/parse_dpr_relevance_data.py
Normal file
47
examples/rag/parse_dpr_relevance_data.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
This script reads DPR retriever training data and parses each datapoint. We save a line per datapoint.
|
||||
Each line consists of the query followed by a tab-separated list of Wikipedia page titles constituting
|
||||
positive contexts for a given query.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--src_path",
|
||||
type=str,
|
||||
default="biencoder-nq-dev.json",
|
||||
help="Path to raw DPR training data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evaluation_set",
|
||||
type=str,
|
||||
help="where to store parsed evaluation_set file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_path",
|
||||
type=str,
|
||||
help="where to store parsed gold_data_path file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.src_path, "r") as src_file, open(args.evaluation_set, "w") as eval_file, open(
|
||||
args.gold_data_path, "w"
|
||||
) as gold_file:
|
||||
dpr_records = json.load(src_file)
|
||||
for dpr_record in tqdm(dpr_records):
|
||||
question = dpr_record["question"]
|
||||
contexts = [context["title"] for context in dpr_record["positive_ctxs"]]
|
||||
eval_file.write(question + "\n")
|
||||
gold_file.write("\t".join(contexts) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
4
examples/rag/requirements.txt
Normal file
4
examples/rag/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
faiss-cpu >= 1.6.3
|
||||
datasets >= 1.0.1
|
||||
psutil >= 5.7.0
|
||||
torch >= 1.4.0
|
||||
156
examples/rag/test_distributed_retriever.py
Normal file
156
examples/rag/test_distributed_retriever.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
import faiss
|
||||
from transformers.configuration_bart import BartConfig
|
||||
from transformers.configuration_dpr import DPRConfig
|
||||
from transformers.configuration_rag import RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
|
||||
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
|
||||
|
||||
def require_distributed_retrieval(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a set of dependencies necessary for pefrorm retrieval with
|
||||
:class:`~transformers.RagRetriever`.
|
||||
|
||||
These tests are skipped when respective libraries are not installed.
|
||||
|
||||
"""
|
||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
@require_distributed_retrieval
|
||||
class RagRetrieverTest(TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.retrieval_vector_size = 8
|
||||
|
||||
# DPR tok
|
||||
vocab_tokens = [
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[PAD]",
|
||||
"[MASK]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
"wa",
|
||||
"un",
|
||||
"runn",
|
||||
"##ing",
|
||||
",",
|
||||
"low",
|
||||
"lowest",
|
||||
]
|
||||
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
|
||||
os.makedirs(dpr_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
# BART tok
|
||||
vocab = [
|
||||
"l",
|
||||
"o",
|
||||
"w",
|
||||
"e",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"i",
|
||||
"d",
|
||||
"n",
|
||||
"\u0120",
|
||||
"\u0120l",
|
||||
"\u0120n",
|
||||
"\u0120lo",
|
||||
"\u0120low",
|
||||
"er",
|
||||
"\u0120lowest",
|
||||
"\u0120newer",
|
||||
"\u0120wider",
|
||||
"<unk>",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||
|
||||
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
|
||||
os.makedirs(bart_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
|
||||
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
|
||||
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
|
||||
|
||||
def get_bart_tokenizer(self) -> BartTokenizer:
|
||||
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_dummy_pytorch_distributed_retriever(self, init_retrieval, port=12345) -> RagPyTorchDistributedRetriever:
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
"text": ["foo", "bar"],
|
||||
"title": ["Foo", "Bar"],
|
||||
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
)
|
||||
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = dataset
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(list(doc_ids), [1, 0])
|
||||
187
examples/rag/utils.py
Normal file
187
examples/rag/utils.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import linecache
|
||||
import re
|
||||
import string
|
||||
from collections import Counter
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from examples.seq2seq.utils import SortishSampler, trim_batch
|
||||
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
|
||||
|
||||
|
||||
def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"):
|
||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {}
|
||||
tokenizer.padding_side = padding_side
|
||||
return tokenizer(
|
||||
[line],
|
||||
max_length=max_length,
|
||||
padding="max_length" if pad_to_max_length else None,
|
||||
truncation=True,
|
||||
return_tensors=return_tensors,
|
||||
add_special_tokens=True,
|
||||
**extra_kw,
|
||||
)
|
||||
|
||||
|
||||
class Seq2SeqDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
data_dir,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
type_path="train",
|
||||
n_obs=None,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
prefix="",
|
||||
):
|
||||
super().__init__()
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
|
||||
self.src_lens = self.get_char_lens(self.src_file)
|
||||
self.max_source_length = max_source_length
|
||||
self.max_target_length = max_target_length
|
||||
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||
self.tokenizer = tokenizer
|
||||
self.prefix = prefix
|
||||
if n_obs is not None:
|
||||
self.src_lens = self.src_lens[:n_obs]
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src_lens)
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
assert source_line, f"empty source line for index {index}"
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
|
||||
# Need to add eos token manually for T5
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
source_line += self.tokenizer.eos_token
|
||||
tgt_line += self.tokenizer.eos_token
|
||||
|
||||
# Pad source and target to the right
|
||||
source_tokenizer = (
|
||||
self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
|
||||
)
|
||||
target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
|
||||
|
||||
source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right")
|
||||
target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right")
|
||||
|
||||
source_ids = source_inputs["input_ids"].squeeze()
|
||||
target_ids = target_inputs["input_ids"].squeeze()
|
||||
src_mask = source_inputs["attention_mask"].squeeze()
|
||||
return {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": src_mask,
|
||||
"decoder_input_ids": target_ids,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
||||
tgt_pad_token_id = (
|
||||
self.tokenizer.generator.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
src_pad_token_id = (
|
||||
self.tokenizer.question_encoder.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
y = trim_batch(target_ids, tgt_pad_token_id)
|
||||
source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks)
|
||||
batch = {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": source_mask,
|
||||
"decoder_input_ids": y,
|
||||
}
|
||||
return batch
|
||||
|
||||
def make_sortish_sampler(self, batch_size):
|
||||
return SortishSampler(self.src_lens, batch_size)
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def f1_score(prediction, ground_truth):
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||
|
||||
|
||||
def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict:
|
||||
assert len(output_lns) == len(reference_lns)
|
||||
em = 0
|
||||
for hypo, pred in zip(output_lns, reference_lns):
|
||||
em += exact_match_score(hypo, pred)
|
||||
if len(output_lns) > 0:
|
||||
em /= len(output_lns)
|
||||
return {"em": em}
|
||||
|
||||
|
||||
def is_rag_model(model_prefix):
|
||||
return model_prefix.startswith("rag")
|
||||
|
||||
|
||||
def set_extra_model_params(extra_params, hparams, config):
|
||||
equivalent_param = {p: p for p in extra_params}
|
||||
# T5 models don't have `dropout` param, they have `dropout_rate` instead
|
||||
equivalent_param["dropout"] = "dropout_rate"
|
||||
for p in extra_params:
|
||||
if getattr(hparams, p, None):
|
||||
if not hasattr(config, p) and not hasattr(config, equivalent_param[p]):
|
||||
logger.info("config doesn't have a `{}` attribute".format(p))
|
||||
delattr(hparams, p)
|
||||
continue
|
||||
set_p = p if hasattr(config, p) else equivalent_param[p]
|
||||
setattr(config, set_p, getattr(hparams, p))
|
||||
delattr(hparams, p)
|
||||
return hparams, config
|
||||
@@ -8,7 +8,7 @@ tensorflow_datasets
|
||||
pytorch-lightning==0.8.5
|
||||
matplotlib
|
||||
git-python==1.0.3
|
||||
faiss
|
||||
faiss-cpu
|
||||
streamlit
|
||||
elasticsearch
|
||||
pandas
|
||||
|
||||
Reference in New Issue
Block a user