Allow Custom Dataset in RAG Retriever (#7763)
* add CustomHFIndex * typo in config * update tests * add custom dataset example * clean script * update test data * minor in test * docs * docs * style * fix imports * allow to pass the indexed dataset directly * update tests * use multiset DPR * address thom and patrick's comments * style * update dpr tokenizer * add output_dir flag in use_own_knowledge_dataset.py * allow custom datasets in examples/rag/finetune.py * add test for custom dataset in distributed rag retriever
This commit is contained in:
@@ -27,13 +27,18 @@ class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
It is used to decode the question and then use the generator_tokenizer.
|
||||
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
|
||||
If specified, use this index instead of the one built using the configuration
|
||||
"""
|
||||
|
||||
_init_retrieval = False
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
||||
super().__init__(
|
||||
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
)
|
||||
|
||||
self.process_group = None
|
||||
|
||||
@@ -90,6 +90,11 @@ class GenerativeQAModule(BaseTransformer):
|
||||
config_class = RagConfig if self.is_rag_model else AutoConfig
|
||||
config = config_class.from_pretrained(hparams.model_name_or_path)
|
||||
|
||||
# set retriever parameters
|
||||
config.index_name = args.index_name or config.index_name
|
||||
config.passages_path = args.passages_path or config.passages_path
|
||||
config.index_path = args.index_path or config.index_path
|
||||
|
||||
# set extra_model_params for generator configs and load_model
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
|
||||
if self.is_rag_model:
|
||||
@@ -97,7 +102,7 @@ class GenerativeQAModule(BaseTransformer):
|
||||
config.generator.prefix = args.prefix
|
||||
config.label_smoothing = hparams.label_smoothing
|
||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path)
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||
prefix = config.question_encoder.prefix
|
||||
else:
|
||||
@@ -405,6 +410,28 @@ class GenerativeQAModule(BaseTransformer):
|
||||
)
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def add_retriever_specific_args(parser):
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--passages_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(args, model=None) -> GenerativeQAModule:
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
@@ -465,6 +492,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
2
examples/rag/test_data/my_knowledge_dataset.csv
Normal file
2
examples/rag/test_data/my_knowledge_dataset.csv
Normal file
@@ -0,0 +1,2 @@
|
||||
Aaron Aaron Aaron ( or ; "Ahärôn") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman ("prophet") to the Pharaoh. Part of the Law (Torah) that Moses received from God at Sinai granted Aaron the priesthood for himself and his male descendants, and he became the first High Priest of the Israelites. Aaron died before the Israelites crossed the North Jordan river and he was buried on Mount Hor (Numbers 33:39; Deuteronomy 10:6 says he died and was buried at Moserah). Aaron is also mentioned in the New Testament of the Bible. According to the Book of Exodus, Aaron first functioned as Moses' assistant. Because Moses complained that he could not speak well, God appointed Aaron as Moses' "prophet" (Exodus 4:10-17; 7:1). At the command of Moses, he let his rod turn into a snake. Then he stretched out his rod in order to bring on the first three plagues. After that, Moses tended to act and speak for himself. During the journey in the wilderness, Aaron was not always prominent or active. At the battle with Amalek, he was chosen with Hur to support the hand of Moses that held the "rod of God". When the revelation was given to Moses at biblical Mount Sinai, he headed the elders of Israel who accompanied Moses on the way to the summit.
|
||||
"Pokémon" Pokémon , also known as in Japan, is a media franchise managed by The Pokémon Company, a Japanese consortium between Nintendo, Game Freak, and Creatures. The franchise copyright is shared by all three companies, but Nintendo is the sole owner of the trademark. The franchise was created by Satoshi Tajiri in 1995, and is centered on fictional creatures called "Pokémon", which humans, known as Pokémon Trainers, catch and train to battle each other for sport. The English slogan for the franchise is "Gotta Catch 'Em All". Works within the franchise are set in the Pokémon universe. The franchise began as "Pokémon Red" and "Green" (released outside of Japan as "Pokémon Red" and "Blue"), a pair of video games for the original Game Boy that were developed by Game Freak and published by Nintendo in February 1996. "Pokémon" has since gone on to become the highest-grossing media franchise of all time, with over in revenue up until March 2017. The original video game series is the second best-selling video game franchise (behind Nintendo's "Mario" franchise) with more than 300million copies sold and over 800million mobile downloads. In addition, the "Pokémon" franchise includes the world's top-selling toy brand, the top-selling trading card game with over 25.7billion cards sold, an anime television series that has become the most successful video game adaptation with over 20 seasons and 1,000 episodes in 124 countries, as well as an anime film series, a , books, manga comics, music, and merchandise. The franchise is also represented in other Nintendo media, such as the "Super Smash Bros." series. In November 2005, 4Kids Entertainment, which had managed the non-game related licensing of "Pokémon", announced that it had agreed not to renew the "Pokémon" representation agreement. The Pokémon Company International oversees all "Pokémon" licensing outside Asia.
|
||||
|
Can't render this file because it contains an unexpected character in line 1 and column 35.
|
@@ -15,6 +15,7 @@ from transformers.configuration_bart import BartConfig
|
||||
from transformers.configuration_dpr import DPRConfig
|
||||
from transformers.configuration_rag import RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.retrieval_rag import CustomHFIndex
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
@@ -114,7 +115,7 @@ class RagRetrieverTest(TestCase):
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_dummy_pytorch_distributed_retriever(self, init_retrieval, port=12345) -> RagPyTorchDistributedRetriever:
|
||||
def get_dummy_dataset(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
@@ -124,6 +125,12 @@ class RagRetrieverTest(TestCase):
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
return dataset
|
||||
|
||||
def get_dummy_pytorch_distributed_retriever(
|
||||
self, init_retrieval: bool, port=12345
|
||||
) -> RagPyTorchDistributedRetriever:
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
@@ -140,6 +147,37 @@ class RagRetrieverTest(TestCase):
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="custom",
|
||||
)
|
||||
if from_disk:
|
||||
config.passages_path = os.path.join(self.tmpdirname, "dataset")
|
||||
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
|
||||
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
|
||||
dataset.drop_index("embeddings")
|
||||
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
|
||||
del dataset
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
)
|
||||
else:
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
index=CustomHFIndex(config.retrieval_vector_size, dataset),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
|
||||
@@ -154,3 +192,33 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_custom_hf_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
199
examples/rag/use_own_knowledge_dataset.py
Normal file
199
examples/rag/use_own_knowledge_dataset.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
import faiss
|
||||
from transformers import (
|
||||
DPRContextEncoder,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
HfArgumentParser,
|
||||
RagRetriever,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenizer,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
torch.set_grad_enabled(False)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def split_text(text: str, n=100, character=" ") -> List[str]:
|
||||
"""Split the text every ``n``-th occurence of ``character``"""
|
||||
text = text.split(character)
|
||||
return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
|
||||
|
||||
|
||||
def split_documents(documents: dict) -> dict:
|
||||
"""Split documents into passages"""
|
||||
titles, texts = [], []
|
||||
for title, text in zip(documents["title"], documents["text"]):
|
||||
for passage in split_text(text):
|
||||
titles.append(title)
|
||||
texts.append(passage)
|
||||
return {"title": titles, "text": texts}
|
||||
|
||||
|
||||
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
|
||||
"""Compute the DPR embeddings of document passages"""
|
||||
input_ids = ctx_tokenizer(
|
||||
documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
|
||||
)["input_ids"]
|
||||
embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
|
||||
return {"embeddings": embeddings.detach().cpu().numpy()}
|
||||
|
||||
|
||||
def main(
|
||||
rag_example_args: "RagExampleArguments",
|
||||
processing_args: "ProcessingArguments",
|
||||
index_hnsw_args: "IndexHnswArguments",
|
||||
):
|
||||
|
||||
######################################
|
||||
logger.info("Step 1 - Create the dataset")
|
||||
######################################
|
||||
|
||||
# The dataset needed for RAG must have three columns:
|
||||
# - title (string): title of the document
|
||||
# - text (string): text of a passage of the document
|
||||
# - embeddings (array of dimension d): DPR representation of the passage
|
||||
|
||||
# Let's say you have documents in tab-separated csv files with columns "title" and "text"
|
||||
assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"
|
||||
|
||||
# You can load a Dataset object this way
|
||||
dataset = load_dataset(
|
||||
"csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
|
||||
)
|
||||
|
||||
# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
|
||||
|
||||
# Then split the documents into passages of 100 words
|
||||
dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)
|
||||
|
||||
# And compute the embeddings
|
||||
ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
|
||||
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
|
||||
dataset = dataset.map(
|
||||
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
|
||||
batched=True,
|
||||
batch_size=processing_args.batch_size,
|
||||
)
|
||||
|
||||
# And finally save your dataset
|
||||
passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
|
||||
dataset.save_to_disk(passages_path)
|
||||
# from datasets import load_from_disk
|
||||
# dataset = load_from_disk(passages_path) # to reload the dataset
|
||||
|
||||
######################################
|
||||
logger.info("Step 2 - Index the dataset")
|
||||
######################################
|
||||
|
||||
# Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
|
||||
index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
|
||||
dataset.add_faiss_index("embeddings", custom_index=index)
|
||||
|
||||
# And save the index
|
||||
index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
|
||||
dataset.get_index("embeddings").save(index_path)
|
||||
# dataset.load_faiss_index("embeddings", index_path) # to reload the index
|
||||
|
||||
######################################
|
||||
logger.info("Step 3 - Load RAG")
|
||||
######################################
|
||||
|
||||
# Easy way to load the model
|
||||
retriever = RagRetriever.from_pretrained(
|
||||
rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset
|
||||
)
|
||||
model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever)
|
||||
tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)
|
||||
|
||||
# For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
|
||||
# retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)
|
||||
|
||||
######################################
|
||||
logger.info("Step 4 - Have fun")
|
||||
######################################
|
||||
|
||||
question = rag_example_args.question or "What does Moses' rod turn into ?"
|
||||
input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
|
||||
generated = model.generate(input_ids)
|
||||
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
|
||||
logger.info("Q: " + question)
|
||||
logger.info("A: " + generated_string)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RagExampleArguments:
|
||||
csv_path: str = field(
|
||||
default=str(Path(__file__).parent / "test_data" / "my_knowledge_dataset.csv"),
|
||||
metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
|
||||
)
|
||||
question: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
|
||||
)
|
||||
rag_model_name: str = field(
|
||||
default="facebook/rag-sequence-nq",
|
||||
metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
|
||||
)
|
||||
dpr_ctx_encoder_model_name: str = field(
|
||||
default="facebook/dpr-ctx_encoder-multiset-base",
|
||||
metadata={
|
||||
"help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
|
||||
},
|
||||
)
|
||||
output_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingArguments:
|
||||
num_proc: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The number of processes to use to split the documents into passages. Default is single process."
|
||||
},
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=16,
|
||||
metadata={
|
||||
"help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexHnswArguments:
|
||||
d: int = field(
|
||||
default=768,
|
||||
metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
|
||||
)
|
||||
m: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The number of bi-directional links created for every new element during the HNSW index construction."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
|
||||
rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
|
||||
main(rag_example_args, processing_args, index_hnsw_args)
|
||||
Reference in New Issue
Block a user