Eli5 examples (#4968)

* add eli5 examples

* add dense query script

* query_di

* merging

* merging

* add_utils

* adds nearest neighbor wikipedia

* batch queries

* training_retriever

* new notebooks

* moved retriever traiing script

* finished wiki40b

* max_len_fix

* train_s2s

* retriever_batch_checkpointing

* cleanup

* merge

* dim_fix

* fix_indexer

* fix_wiki40b_snippets

* fix_embed_for_r

* fp32 index

* fix_sparse_q

* joint_training

* remove obsolete datasets

* add_passage_nn_results

* add_passage_nn_results

* add_batch_nn

* add_batch_nn

* add_data_scripts

* notebook

* notebook

* notebook

* fix_multi_gpu

* add_app

* full_caching

* full_caching

* notebook

* sparse_done

* images

* notebook

* add_image_gif

* with_Gif

* add_contr_image

* notebook

* notebook

* notebook

* train_functions

* notebook

* min_retrieval_length

* pandas_option

* notebook

* min_retrieval_length

* notebook

* notebook

* eval_Retriever

* notebook

* images

* notebook

* add_example

* add_example

* notebook

* fireworks

* notebook

* notebook

* joe's notebook comments

* app_update

* notebook

* notebook_link

* captions

* notebook

* assing RetriBert model

* add RetriBert to Auto

* change AutoLMHead to AutoSeq2Seq

* notebook downloads from hf models

* style_black

* style_black

* app_update

* app_update

* fix_app_update

* style

* style

* isort

* Delete WikiELI5training.ipynb

* Delete evaluate_eli5.py

* Delete WikiELI5explore.ipynb

* Delete ExploreWikiELI5Support.html

* Delete explainlikeimfive.py

* Delete wiki_snippets.py

* children before parent

* children before parent

* style_black

* style_black_only

* isort

* isort_new

* Update src/transformers/modeling_retribert.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* typo fixes

* app_without_asset

* cleanup

* Delete ELI5animation.gif

* Delete ELI5contrastive.svg

* Delete ELI5wiki_index.svg

* Delete choco_bis.svg

* Delete fireworks.gif

* Delete huggingface_logo.jpg

* Delete huggingface_logo.svg

* Delete Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb

* Delete eli5_app.py

* Delete eli5_utils.py

* readme

* Update README.md

* unused imports

* moved_info

* default_beam

* ftuned model

* disclaimer

* Update src/transformers/modeling_retribert.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* black

* add_doc

* names

* isort_Examples

* isort_Examples

* Add doc to index

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Yacine Jernite
2020-06-16 16:36:58 -04:00
committed by GitHub
parent c3e607496c
commit 49c5202522
14 changed files with 1423 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
# Long Form Question Answering
This folder contains the code for the Long Form Question answering [demo](http://35.226.96.115:8080/) as well as methods to train and use a fully end-to-end Long Form Question Answering system using the [🤗transformers](https://github.com/huggingface/transformers) and [🤗nlp](https://github.com/huggingface/nlp) libraries.
You can use these mothods to train your own system by following along the associate [notebook](https://github.com/huggingface/notebooks/blob/master/longform-qa/Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb) or [blog post](https://yjernite.github.io/lfqa.html).

View File

@@ -0,0 +1,332 @@
import numpy as np
import torch
import faiss
import nlp
import streamlit as st
import transformers
from elasticsearch import Elasticsearch
from eli5_utils import (
embed_questions_for_retrieval,
make_qa_s2s_model,
qa_s2s_generate,
query_es_index,
query_qa_dense_index,
)
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
MODEL_TYPE = "bart"
LOAD_DENSE_INDEX = True
@st.cache(allow_output_mutation=True)
def load_models():
if LOAD_DENSE_INDEX:
qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased")
qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to("cuda:0")
_ = qar_model.eval()
else:
qar_tokenizer, qar_model = (None, None)
if MODEL_TYPE == "bart":
s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to("cuda:0")
save_dict = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
s2s_model.load_state_dict(save_dict["model"])
_ = s2s_model.eval()
else:
s2s_tokenizer, s2s_model = make_qa_s2s_model(
model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device="cuda:0"
)
return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)
@st.cache(allow_output_mutation=True)
def load_indexes():
if LOAD_DENSE_INDEX:
faiss_res = faiss.StandardGpuResources()
wiki40b_passages = nlp.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
wiki40b_passage_reps = np.memmap(
"wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
dtype="float32",
mode="r",
shape=(wiki40b_passages.num_rows, 128),
)
wiki40b_index_flat = faiss.IndexFlatIP(128)
wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
wiki40b_gpu_index_flat.add(wiki40b_passage_reps) # TODO fix for larger GPU
else:
wiki40b_passages, wiki40b_gpu_index_flat = (None, None)
es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
return (wiki40b_passages, wiki40b_gpu_index_flat, es_client)
@st.cache(allow_output_mutation=True)
def load_train_data():
eli5 = nlp.load_dataset("eli5", name="LFQA_reddit")
eli5_train = eli5["train_eli5"]
eli5_train_q_reps = np.memmap(
"eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128)
)
eli5_train_q_index = faiss.IndexFlatIP(128)
eli5_train_q_index.add(eli5_train_q_reps)
return (eli5_train, eli5_train_q_index)
passages, gpu_dense_index, es_client = load_indexes()
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
eli5_train, eli5_train_q_index = load_train_data()
def find_nearest_training(question, n_results=10):
q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
D, I = eli5_train_q_index.search(q_rep, n_results)
nn_examples = [eli5_train[int(i)] for i in I[0]]
return nn_examples
def make_support(question, source="wiki40b", method="dense", n_results=10):
if source == "none":
support_doc, hit_lst = (" <P> ".join(["" for _ in range(11)]).strip(), [])
else:
if method == "dense":
support_doc, hit_lst = query_qa_dense_index(
question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
)
else:
support_doc, hit_lst = query_es_index(
question, es_client, index_name="english_wiki40b_snippets_100w", n_results=n_results,
)
support_list = [
(res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
]
question_doc = "question: {} context: {}".format(question, support_doc)
return question_doc, support_list
@st.cache(hash_funcs={torch.Tensor: (lambda _: None), transformers.tokenization_bart.BartTokenizer: (lambda _: None)})
def answer_question(
question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
with torch.no_grad():
answer = qa_s2s_generate(
question_doc,
s2s_model,
s2s_tokenizer,
num_answers=1,
num_beams=n_beams,
min_len=min_len,
max_len=max_len,
do_sample=sampling,
temp=temp,
top_p=top_p,
top_k=None,
max_input_length=1024,
device="cuda:0",
)[0]
return (answer, support_list)
st.title("Long Form Question Answering with ELI5")
# Start sidebar
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
header_full = """
<html>
<head>
<style>
.img-container {
padding-left: 90px;
padding-right: 90px;
padding-top: 50px;
padding-bottom: 50px;
background-color: #f0f3f9;
}
</style>
</head>
<body>
<span class="img-container"> <!-- Inline parent element -->
%s
</span>
</body>
</html>
""" % (
header_html,
)
st.sidebar.markdown(
header_full, unsafe_allow_html=True,
)
# Long Form QA with ELI5 and Wikipedia
description = """
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
First, a document retriever fetches a set of relevant Wikipedia passages given the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
a pre-processed fixed snapshot of Wikipedia.
"""
st.sidebar.markdown(description, unsafe_allow_html=True)
action_list = [
"Answer the question",
"View the retrieved document only",
"View the most similar ELI5 question and answer",
"Show me everything, please!",
]
demo_options = st.sidebar.checkbox("Demo options")
if demo_options:
action_st = st.sidebar.selectbox("", action_list, index=3,)
action = action_list.index(action_st)
show_type = st.sidebar.selectbox("", ["Show full text of passages", "Show passage section titles"], index=0,)
show_passages = show_type == "Show full text of passages"
else:
action = 3
show_passages = True
retrieval_options = st.sidebar.checkbox("Retrieval options")
if retrieval_options:
retriever_info = """
### Information retriever options
The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between a question and passage embedding
trained using the [ELI5](https://arxiv.org/abs/1907.09190) questions-answer pairs.
The answer is then generated by sequence to sequence model which takes the question and retrieved document as input.
"""
st.sidebar.markdown(retriever_info)
wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])
else:
wiki_source = "wiki40b"
index_type = "dense"
sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
generate_info = """
### Answer generation options
The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can use the model for greedy decoding with
**beam** search, or **sample** from the decoder's output probabilities.
"""
st.sidebar.markdown(generate_info)
sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
min_len = st.sidebar.slider(
"Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
)
max_len = st.sidebar.slider(
"Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
)
if sampled == "beam":
n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
else:
top_p = st.sidebar.slider(
"Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
)
temp = st.sidebar.slider(
"Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
)
n_beams = None
# start main text
questions_list = [
"<MY QUESTION>",
"How do people make chocolate?",
"Why do we get a fever when we are sick?",
"How can different animals perceive different colors?",
"What is natural language processing?",
"What's the best way to treat a sunburn?",
"What exactly are vitamins ?",
"How does nuclear energy provide electricity?",
"What's the difference between viruses and bacteria?",
"Why are flutes classified as woodwinds when most of them are made out of metal ?",
"Why do people like drinking coffee even though it tastes so bad?",
"What happens when wine ages? How does it make the wine taste better?",
"If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
"How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
"How does New Zealand have so many large bird predators?",
]
question_s = st.selectbox(
"What would you like to ask? ---- select <MY QUESTION> to enter a new query", questions_list, index=1,
)
if question_s == "<MY QUESTION>":
question = st.text_input("Enter your question here:", "")
else:
question = question_s
if st.button("Show me!"):
if action in [0, 1, 3]:
if index_type == "mixed":
_, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
_, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
support_list = []
for res_d, res_s in zip(support_list_dense, support_list_sparse):
if tuple(res_d) not in support_list:
support_list += [tuple(res_d)]
if tuple(res_s) not in support_list:
support_list += [tuple(res_s)]
support_list = support_list[:10]
question_doc = "<P> " + " <P> ".join([res[-1] for res in support_list])
else:
question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
if action in [0, 3]:
answer, support_list = answer_question(
question_doc,
s2s_model,
s2s_tokenizer,
min_len=min_len,
max_len=int(max_len),
sampling=(sampled == "sampled"),
n_beams=n_beams,
top_p=top_p,
temp=temp,
)
st.markdown("### The model generated answer is:")
st.write(answer)
if action in [0, 1, 3] and wiki_source != "none":
st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
for i, res in enumerate(support_list):
wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
sec_titles = res[1].strip()
if sec_titles == "":
sections = "[{}]({})".format(res[0], wiki_url)
else:
sec_list = sec_titles.split(" & ")
sections = " & ".join(
["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
)
st.markdown(
"{0:02d} - **Article**: {1:<18} <br> _Section_: {2}".format(i + 1, res[0], sections),
unsafe_allow_html=True,
)
if show_passages:
st.write(
'> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
)
if action in [2, 3]:
nn_train_list = find_nearest_training(question)
train_exple = nn_train_list[0]
st.markdown(
"--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
)
answers_st = [
"{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
if i == 0 or sc > 2
]
st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))
disclaimer = """
---
**Disclaimer**
*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
"""
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)

View File

@@ -0,0 +1,653 @@
import functools
import math
import os # noqa: F401
from random import choice, randint
from time import time
import numpy as np
import torch
import torch.utils.checkpoint as checkpoint
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm
import faiss # noqa: F401
import nlp # noqa: F401
import pandas as pd
from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup
pd.set_option("display.max_colwidth", None)
###############
# Sparse index
###############
def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_kilt_snippets_100w"):
index_config = {
"settings": {
"number_of_shards": 1,
"analysis": {"analyzer": {"stop_standard": {"type": "standard", " stopwords": "_english_"}}},
},
"mappings": {
"properties": {
"article_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
"section_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
"passage_text": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
}
},
}
es_client.indices.create(index=index_name, body=index_config)
number_of_docs = passages_dset.num_rows
progress = tqdm(unit="docs", total=number_of_docs)
successes = 0
def passage_generator():
for passage in passages_dset:
yield passage
# create the ES index
for ok, action in streaming_bulk(client=es_client, index=index_name, actions=passage_generator(),):
progress.update(1)
successes += ok
print("Indexed %d documents" % (successes,))
def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_100w", n_results=10, min_length=20):
q = question.lower()
banned = ["how", "why", "what", "where", "which", "do", "does", "is", "?", "eli5", "eli5:"]
q = " ".join([w for w in q.split() if w not in banned])
response = es_client.search(
index=index_name,
body={
"query": {
"multi_match": {
"query": q,
"fields": ["article_title", "section_title", "passage_text^2"],
"type": "cross_fields",
}
},
"size": 2 * n_results,
},
)
hits = response["hits"]["hits"]
support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
for r, hit in zip(res_list, hits):
r["passage_id"] = hit["_id"]
r["score"] = hit["_score"]
r["passage_text"] = hit["_source"]["passage_text"]
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
return support_doc, res_list
###############
# ELI5 retriever training
###############
class ELI5DatasetQARetriver(Dataset):
def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):
self.data = examples_array
self.answer_thres = extra_answer_threshold
self.min_length = min_answer_length
self.training = training
self.n_samples = self.data.num_rows if n_samples is None else n_samples
def __len__(self):
return self.n_samples
def make_example(self, idx):
example = self.data[idx]
question = example["title"]
if self.training:
answers = [a for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"]))]
answer_tab = choice(answers).split(" ")
start_idx = randint(0, max(0, len(answer_tab) - self.min_length))
answer_span = " ".join(answer_tab[start_idx:])
else:
answer_span = example["answers"]["text"][0]
return (question, answer_span)
def __getitem__(self, idx):
return self.make_example(idx % self.data.num_rows)
class RetrievalQAEmbedder(torch.nn.Module):
def __init__(self, sent_encoder, dim):
super(RetrievalQAEmbedder, self).__init__()
self.sent_encoder = sent_encoder
self.output_dim = 128
self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
# reproduces BERT forward pass with checkpointing
if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
return self.sent_encoder(input_ids, attention_mask=attention_mask)[1]
else:
# prepare implicit variables
device = input_ids.device
input_shape = input_ids.size()
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
attention_mask, input_shape, device
)
# define function for checkpointing
def partial_encode(*inputs):
encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
sequence_output = encoder_outputs[0]
pooled_output = self.sent_encoder.pooler(sequence_output)
return pooled_output
# run embedding layer on everything at once
embedding_output = self.sent_encoder.embeddings(
input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
)
# run encoding and pooling on one mini-batch at a time
pooled_output_list = []
for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
pooled_output_list.append(pooled_output)
return torch.cat(pooled_output_list, dim=0)
def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):
q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)
return self.project_q(q_reps)
def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):
a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)
return self.project_a(a_reps)
def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):
device = q_ids.device
q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)
a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)
compare_scores = torch.mm(q_reps, a_reps.t())
loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
loss = (loss_qa + loss_aq) / 2
return loss
def make_qa_retriever_model(model_name="google/bert_uncased_L-8_H-512_A-8", from_file=None, device="cuda:0"):
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModel.from_pretrained(model_name).to(device)
# run bert_model on a dummy batch to get output dimension
d_ids = torch.LongTensor(
[[bert_model.config.bos_token_id if bert_model.config.bos_token_id is not None else 1]]
).to(device)
d_mask = torch.LongTensor([[1]]).to(device)
sent_dim = bert_model(d_ids, attention_mask=d_mask)[1].shape[-1]
qa_embedder = RetrievalQAEmbedder(bert_model, sent_dim).to(device)
if from_file is not None:
param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states
qa_embedder.load_state_dict(param_dict["model"])
return tokenizer, qa_embedder
def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"):
q_ls = [q for q, a in qa_list]
a_ls = [a for q, a in qa_list]
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),
)
a_toks = tokenizer.batch_encode_plus(a_ls, max_length=max_len, pad_to_max_length=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
)
return (q_ids, q_mask, a_ids, a_mask)
def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0):
model.train()
# make iterator
train_sampler = RandomSampler(dataset)
model_collate_fn = functools.partial(
make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
# accumulate loss since last print
loc_steps = 0
loc_loss = 0.0
st_time = time()
for step, batch in enumerate(epoch_iterator):
q_ids, q_mask, a_ids, a_mask = batch
pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
loss = pre_loss.sum()
# optimizer
loss.backward()
optimizer.step()
scheduler.step()
model.zero_grad()
# some printing within the epoch
loc_loss += loss.item()
loc_steps += 1
if step % args.print_freq == 0 or step == 1:
print(
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
)
)
loc_loss = 0
loc_steps = 0
def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=0):
model.train()
model_collate_fn = functools.partial(
make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
# make iterator
train_samplers = [RandomSampler(dataset) for dataset in dataset_list]
data_loaders = [
DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
for dataset, train_sampler in zip(dataset_list, train_samplers)
]
iterators = [iter(dloader) for dloader in data_loaders]
joint_iter = zip(*iterators)
# accumulate loss since last print
loc_steps = 0
loc_loss = 0.0
st_time = time()
for step, (batches,) in enumerate(zip(joint_iter)):
for batch in batches:
q_ids, q_mask, a_ids, a_mask = batch
loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
# optimizer
loss.backward()
optimizer.step()
scheduler.step()
model.zero_grad()
# some printing within the epoch
loc_loss += loss.item()
loc_steps += 1
if step % args.print_freq == 0:
print(
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,
)
)
loc_loss = 0
loc_steps = 0
def evaluate_qa_retriever(model, dataset, tokenizer, args):
model.eval()
# make iterator
eval_sampler = SequentialSampler(dataset)
model_collate_fn = functools.partial(
make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
tot_loss = 0.0
with torch.no_grad():
for step, batch in enumerate(epoch_iterator):
q_ids, q_mask, a_ids, a_mask = batch
loss = model(q_ids, q_mask, a_ids, a_mask)
tot_loss += loss.item()
return tot_loss / (step + 1)
def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args):
qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)
qar_scheduler = get_linear_schedule_with_warmup(
qar_optimizer,
num_warmup_steps=100,
num_training_steps=(qar_args.num_epochs + 1) * math.ceil(len(qar_train_dset) / qar_args.batch_size),
)
for e in range(qar_args.num_epochs):
train_qa_retriever_epoch(qar_model, qar_train_dset, qar_tokenizer, qar_optimizer, qar_scheduler, qar_args, e)
m_save_dict = {
"model": qar_model.state_dict(),
"optimizer": qar_optimizer.state_dict(),
"scheduler": qar_scheduler.state_dict(),
}
print("Saving model {}".format(qar_args.model_save_name))
torch.save(m_save_dict, "{}_{}.pth".format(qar_args.model_save_name, e))
eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))
###############
# ELI5 seq2seq model training
###############
class ELI5DatasetS2S(Dataset):
def __init__(
self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True
):
self.training = training
self.data = examples_array
self.make_doc_function = make_doc_fun
self.document_cache = {} if document_cache is None else document_cache
assert not (make_doc_fun is None and document_cache is None)
# make index of specific question-answer pairs from multi-answers
if self.training:
self.qa_id_list = [
(i, j)
for i, qa in enumerate(self.data)
for j, (a, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"]))
if j == 0 or sc >= extra_answer_threshold
]
else:
self.qa_id_list = [(i, 0) for i in range(self.data.num_rows)]
def __len__(self):
return len(self.qa_id_list)
def make_example(self, idx):
i, j = self.qa_id_list[idx]
example = self.data[i]
question = example["title"] + " " + example["selftext"]
answer = example["answers"]["text"][j]
q_id = example["q_id"]
if self.make_doc_function is not None:
self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["title"]))
document = self.document_cache[q_id]
in_st = "question: {} context: {}".format(
question.lower().replace(" --t--", "").strip(), document.lower().strip(),
)
out_st = answer
return (in_st, out_st)
def __getitem__(self, idx):
return self.make_example(idx)
def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
if from_file is not None:
param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states
model.load_state_dict(param_dict["model"])
return tokenizer, model
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
q_ls = [q for q, a in qa_list]
a_ls = [a for q, a in qa_list]
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),
)
a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
)
lm_labels = a_ids[:, 1:].contiguous().clone()
lm_labels[a_mask[:, 1:].contiguous() == 0] = -100
model_inputs = {
"input_ids": q_ids,
"attention_mask": q_mask,
"decoder_input_ids": a_ids[:, :-1].contiguous(),
"lm_labels": lm_labels,
}
return model_inputs
def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
model.train()
# make iterator
if curriculum:
train_sampler = SequentialSampler(dataset)
else:
train_sampler = RandomSampler(dataset)
model_collate_fn = functools.partial(
make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
# accumulate loss since last print
loc_steps = 0
loc_loss = 0.0
st_time = time()
for step, batch_inputs in enumerate(epoch_iterator):
pre_loss = model(**batch_inputs)[0]
loss = pre_loss.sum() / pre_loss.shape[0]
loss.backward()
# optimizer
if step % args.backward_freq == 0:
optimizer.step()
scheduler.step()
model.zero_grad()
# some printing within the epoch
loc_loss += loss.item()
loc_steps += 1
if step % args.print_freq == 0 or step == 1:
print(
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
)
)
loc_loss = 0
loc_steps = 0
def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
model.eval()
# make iterator
train_sampler = SequentialSampler(dataset)
model_collate_fn = functools.partial(
make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
# accumulate loss since last print
loc_steps = 0
loc_loss = 0.0
st_time = time()
with torch.no_grad():
for step, batch_inputs in enumerate(epoch_iterator):
pre_loss = model(**batch_inputs)[0]
loss = pre_loss.sum() / pre_loss.shape[0]
loc_loss += loss.item()
loc_steps += 1
if step % args.print_freq == 0:
print(
"{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
)
)
print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,))
def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
s2s_scheduler = get_linear_schedule_with_warmup(
s2s_optimizer,
num_warmup_steps=400,
num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
)
for e in range(s2s_args.num_epochs):
train_qa_s2s_epoch(
qa_s2s_model,
s2s_train_dset,
qa_s2s_tokenizer,
s2s_optimizer,
s2s_scheduler,
s2s_args,
e,
curriculum=(e == 0),
)
m_save_dict = {
"model": qa_s2s_model.state_dict(),
"optimizer": s2s_optimizer.state_dict(),
"scheduler": s2s_scheduler.state_dict(),
}
print("Saving model {}".format(s2s_args.model_save_name))
eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))
# generate answer from input "question: ... context: <p> ..."
def qa_s2s_generate(
question_doc,
qa_s2s_model,
qa_s2s_tokenizer,
num_answers=1,
num_beams=None,
min_len=64,
max_len=256,
do_sample=False,
temp=1.0,
top_p=None,
top_k=None,
max_input_length=512,
device="cuda:0",
):
model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,)
n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
generated_ids = qa_s2s_model.generate(
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
min_length=min_len,
max_length=max_len,
do_sample=do_sample,
early_stopping=True,
num_beams=1 if do_sample else n_beams,
temperature=temp,
top_k=top_k,
top_p=top_p,
eos_token_id=qa_s2s_tokenizer.eos_token_id,
no_repeat_ngram_size=3,
num_return_sequences=num_answers,
decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,
)
return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]
###############
# ELI5-trained retrieval model usage
###############
def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"):
a_toks = tokenizer.batch_encode_plus(passages, max_length=max_length, pad_to_max_length=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
)
with torch.no_grad():
a_reps = qa_embedder.embed_answers(a_ids, a_mask).cpu().type(torch.float)
return a_reps.numpy()
def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"):
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=128, pad_to_max_length=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),
)
with torch.no_grad():
q_reps = qa_embedder.embed_questions(q_ids, q_mask).cpu().type(torch.float)
return q_reps.numpy()
def make_qa_dense_index(
qa_embedder,
tokenizer,
passages_dset,
batch_size=512,
max_length=128,
index_name="kilt_passages_reps.dat",
dtype="float32",
device="cuda:0",
):
st_time = time()
fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
n_batches = math.ceil(passages_dset.num_rows / batch_size)
for i in range(n_batches):
passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
fp[i * batch_size : (i + 1) * batch_size] = reps
if i % 50 == 0:
print(i, time() - st_time)
def evaluate_retriever(qa_list, retriever_func, scoring_func, n_ret=10, verbose=False):
total_retriever_time = 0.0
total_retriever_score = 0.0
st_time = time()
for i, (question, answer) in enumerate(qa_list):
r_time = time()
retrieved_passages = retriever_func(question, n_ret)
total_retriever_time += time() - r_time
total_retriever_score += scoring_func(retrieved_passages, answer)
if verbose and ((i + 1) % 500 == 0 or i <= 1):
print(
"{:03d}: S-{:.4f} T-{:.4f} | {:.2f}".format(
i + 1, total_retriever_score / (i + 1), total_retriever_time / (i + 1), time() - st_time
)
)
return {"idf_recall": total_retriever_score / (i + 1), "retrieval_time": total_retriever_time / (i + 1)}
# build a support document for the question out of Wikipedia snippets
def query_qa_dense_index(
question, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20, device="cuda:0"
):
q_rep = embed_questions_for_retrieval([question], tokenizer, qa_embedder, device=device)
D, I = wiki_index.search(q_rep, 2 * n_results)
res_passages = [wiki_passages[int(i)] for i in I[0]]
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
for r, sc in zip(res_list, D[0]):
r["score"] = float(sc)
return support_doc, res_list
def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
q_rep = embed_questions_for_retrieval(questions, tokenizer, qa_embedder)
D, I = wiki_index.search(q_rep, n_results)
res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
support_doc_lst = [
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
]
all_res_lists = []
for (res_passages, dl) in zip(res_passages_lst, D):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
for r, sc in zip(res_list, dl):
r["score"] = float(sc)
all_res_lists += [res_list[:]]
return support_doc_lst, all_res_lists
# find nearest neighbors of an answer or declarative text in Wikipedia snippets
def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20):
a_rep = embed_passages_for_retrieval([passage], tokenizer, qa_embedder)
D, I = wiki_index.search(a_rep, 2 * n_results)
res_passages = [wiki_passages[int(i)] for i in I[0]]
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
for r, sc, i in zip(res_list, D[0], I[0]):
r["passage_id"] = int(i)
r["score"] = float(sc)
return support_doc, res_list
def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
a_reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder)
D, I = wiki_index.search(a_reps, n_results)
res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
support_doc_lst = [
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
]
all_res_lists = []
for (res_passages, dl, il) in zip(res_passages_lst, D, I):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
for r, sc, i in zip(res_list, dl, il):
r["passage_id"] = int(i)
r["score"] = float(sc)
all_res_lists += [res_list[:]]
return support_doc_lst, all_res_lists