Eli5 examples (#4968)
* add eli5 examples * add dense query script * query_di * merging * merging * add_utils * adds nearest neighbor wikipedia * batch queries * training_retriever * new notebooks * moved retriever traiing script * finished wiki40b * max_len_fix * train_s2s * retriever_batch_checkpointing * cleanup * merge * dim_fix * fix_indexer * fix_wiki40b_snippets * fix_embed_for_r * fp32 index * fix_sparse_q * joint_training * remove obsolete datasets * add_passage_nn_results * add_passage_nn_results * add_batch_nn * add_batch_nn * add_data_scripts * notebook * notebook * notebook * fix_multi_gpu * add_app * full_caching * full_caching * notebook * sparse_done * images * notebook * add_image_gif * with_Gif * add_contr_image * notebook * notebook * notebook * train_functions * notebook * min_retrieval_length * pandas_option * notebook * min_retrieval_length * notebook * notebook * eval_Retriever * notebook * images * notebook * add_example * add_example * notebook * fireworks * notebook * notebook * joe's notebook comments * app_update * notebook * notebook_link * captions * notebook * assing RetriBert model * add RetriBert to Auto * change AutoLMHead to AutoSeq2Seq * notebook downloads from hf models * style_black * style_black * app_update * app_update * fix_app_update * style * style * isort * Delete WikiELI5training.ipynb * Delete evaluate_eli5.py * Delete WikiELI5explore.ipynb * Delete ExploreWikiELI5Support.html * Delete explainlikeimfive.py * Delete wiki_snippets.py * children before parent * children before parent * style_black * style_black_only * isort * isort_new * Update src/transformers/modeling_retribert.py Co-authored-by: Julien Chaumond <chaumond@gmail.com> * typo fixes * app_without_asset * cleanup * Delete ELI5animation.gif * Delete ELI5contrastive.svg * Delete ELI5wiki_index.svg * Delete choco_bis.svg * Delete fireworks.gif * Delete huggingface_logo.jpg * Delete huggingface_logo.svg * Delete Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb * Delete 
eli5_app.py * Delete eli5_utils.py * readme * Update README.md * unused imports * moved_info * default_beam * ftuned model * disclaimer * Update src/transformers/modeling_retribert.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * black * add_doc * names * isort_Examples * isort_Examples * Add doc to index Co-authored-by: Julien Chaumond <chaumond@gmail.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
332
examples/longform-qa/eli5_app.py
Normal file
332
examples/longform-qa/eli5_app.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import faiss
|
||||
import nlp
|
||||
import streamlit as st
|
||||
import transformers
|
||||
from elasticsearch import Elasticsearch
|
||||
from eli5_utils import (
|
||||
embed_questions_for_retrieval,
|
||||
make_qa_s2s_model,
|
||||
qa_s2s_generate,
|
||||
query_es_index,
|
||||
query_qa_dense_index,
|
||||
)
|
||||
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
# Selects the answer generator built by load_models(): "bart" loads the BART ELI5 model,
# any other value makes load_models() build a t5-small model via make_qa_s2s_model().
MODEL_TYPE = "bart"
|
||||
# When True, load_models() loads the retribert question embedder and load_indexes() builds
# the GPU faiss index over the wiki40b passages; when False those objects are left as None.
LOAD_DENSE_INDEX = True
|
||||
|
||||
|
||||
@st.cache(allow_output_mutation=True)
def load_models():
    """Load the retriever and the answer generator, both placed on ``cuda:0``.

    Returns:
        A ``(qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)`` tuple. The two
        retriever entries are ``None`` when ``LOAD_DENSE_INDEX`` is false.
    """
    device = "cuda:0"

    # Dense retriever: only needed when the dense index is in use.
    qar_tokenizer, qar_model = None, None
    if LOAD_DENSE_INDEX:
        retriever_name = "yjernite/retribert-base-uncased"
        qar_tokenizer = AutoTokenizer.from_pretrained(retriever_name)
        qar_model = AutoModel.from_pretrained(retriever_name).to(device)
        qar_model.eval()

    # Non-BART configuration: the helper builds the T5 model from a local checkpoint.
    if MODEL_TYPE != "bart":
        s2s_tokenizer, s2s_model = make_qa_s2s_model(
            model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device=device
        )
        return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)

    generator_name = "yjernite/bart_eli5"
    s2s_tokenizer = AutoTokenizer.from_pretrained(generator_name)
    s2s_model = AutoModelForSeq2SeqLM.from_pretrained(generator_name).to(device)
    # The hub weights are then overwritten with the "model" entry of a local checkpoint.
    checkpoint = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
    s2s_model.load_state_dict(checkpoint["model"])
    s2s_model.eval()
    return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)
|
||||
|
||||
|
||||
@st.cache(allow_output_mutation=True)
def load_indexes():
    """Open the passage collection, its dense faiss index and the Elasticsearch client.

    Returns:
        A ``(wiki40b_passages, wiki40b_gpu_index_flat, es_client)`` tuple. The first
        two entries are ``None`` when ``LOAD_DENSE_INDEX`` is false.
    """
    wiki40b_passages, wiki40b_gpu_index_flat = None, None
    if LOAD_DENSE_INDEX:
        embed_dim = 128  # width of each stored passage vector in the memmap below
        gpu_resources = faiss.StandardGpuResources()
        wiki40b_passages = nlp.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
        # Pre-computed passage embeddings, memory-mapped read-only with one row per passage.
        passage_reps = np.memmap(
            "wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
            dtype="float32",
            mode="r",
            shape=(wiki40b_passages.num_rows, embed_dim),
        )
        cpu_index = faiss.IndexFlatIP(embed_dim)
        # NOTE(review): the index goes to GPU device 1 while the models in load_models()
        # are placed on cuda:0 -- confirm this split is intended.
        wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(gpu_resources, 1, cpu_index)
        wiki40b_gpu_index_flat.add(passage_reps)  # TODO fix for larger GPU

    es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
    return (wiki40b_passages, wiki40b_gpu_index_flat, es_client)
|
||||
|
||||
|
||||
@st.cache(allow_output_mutation=True)
def load_train_data():
    """Load the ELI5 training split and a faiss index over its question embeddings.

    Returns:
        ``(eli5_train, eli5_train_q_index)``: the ``train_eli5`` split of the
        ``LFQA_reddit`` configuration and an inner-product index filled from
        ``eli5_questions_reps.dat``.
    """
    embed_dim = 128  # width of each stored question vector
    eli5_train = nlp.load_dataset("eli5", name="LFQA_reddit")["train_eli5"]
    # Pre-computed question embeddings, memory-mapped read-only with one row per example.
    question_reps = np.memmap(
        "eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, embed_dim)
    )
    eli5_train_q_index = faiss.IndexFlatIP(embed_dim)
    eli5_train_q_index.add(question_reps)
    return (eli5_train, eli5_train_q_index)
|
||||
|
||||
|
||||
# Passage dataset, dense passage index and Elasticsearch client, as returned by load_indexes().
passages, gpu_dense_index, es_client = load_indexes()
|
||||
# Retriever tokenizer/model and seq2seq tokenizer/model, as returned by load_models().
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
|
||||
# ELI5 training examples and the faiss index over their question embeddings.
eli5_train, eli5_train_q_index = load_train_data()
|
||||
|
||||
|
||||
def find_nearest_training(question, n_results=10):
    """Return the ELI5 training examples whose questions are nearest to ``question``.

    The question is embedded with the module-level retriever and looked up in
    ``eli5_train_q_index``; the ``n_results`` hits are returned as a list of
    rows of ``eli5_train``, in the order the index reports them.
    """
    query_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
    _scores, neighbor_ids = eli5_train_q_index.search(query_rep, n_results)
    # A single query was submitted, so only the first row of ids is relevant.
    return [eli5_train[int(idx)] for idx in neighbor_ids[0]]
|
||||
|
||||
|
||||
def make_support(question, source="wiki40b", method="dense", n_results=10):
    """Build the generator input and the list of supporting passages for a question.

    Args:
        question: The question text.
        source: ``"none"`` skips retrieval and uses an empty placeholder document;
            any other value triggers retrieval.
        method: ``"dense"`` queries the dense index; any other value queries the
            Elasticsearch index.
        n_results: Number of passages to request (and number of empty passage
            slots in the placeholder document when ``source`` is ``"none"``).

    Returns:
        ``(question_doc, support_list)`` where ``question_doc`` is
        ``"question: <question> context: <support document>"`` and
        ``support_list`` holds one
        ``(article_title, section_title, score, passage_text)`` tuple per hit.
    """
    if source == "none":
        # Placeholder with one "<P>" marker per requested passage and no text.
        # Sized from n_results instead of a hard-coded 11 so it keeps matching
        # the requested passage count (identical output for the default of 10).
        support_doc = " <P> ".join([""] * (n_results + 1)).strip()
        hit_lst = []
    elif method == "dense":
        support_doc, hit_lst = query_qa_dense_index(
            question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
        )
    else:
        support_doc, hit_lst = query_es_index(
            question, es_client, index_name="english_wiki40b_snippets_100w", n_results=n_results,
        )

    support_list = [
        (res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
    ]
    question_doc = "question: {} context: {}".format(question, support_doc)
    return question_doc, support_list
|
||||
|
||||
|
||||
@st.cache(hash_funcs={torch.Tensor: (lambda _: None), transformers.tokenization_bart.BartTokenizer: (lambda _: None)})
def answer_question(
    question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
    """Generate one answer for ``question_doc`` with the seq2seq model.

    Args:
        question_doc: Generator input text (question plus supporting document).
        s2s_model, s2s_tokenizer: Generator model and its tokenizer.
        min_len, max_len: Length bounds forwarded to ``qa_s2s_generate``.
        sampling: Forwarded as ``do_sample``.
        n_beams, top_p, temp: Decoding settings forwarded to ``qa_s2s_generate``.

    Returns:
        ``(answer, support_list)``: the first generated answer together with the
        module-level ``support_list``.
    """
    # NOTE(review): ``support_list`` is not a parameter; it is read from module scope
    # and handed back unchanged because the caller unpacks two values.
    generation_settings = dict(
        num_answers=1,
        num_beams=n_beams,
        min_len=min_len,
        max_len=max_len,
        do_sample=sampling,
        temp=temp,
        top_p=top_p,
        top_k=None,
        max_input_length=1024,
        device="cuda:0",
    )
    with torch.no_grad():
        candidates = qa_s2s_generate(question_doc, s2s_model, s2s_tokenizer, **generation_settings)
    answer = candidates[0]
    return (answer, support_list)
|
||||
|
||||
|
||||
st.title("Long Form Question Answering with ELI5")

# Start sidebar
# Logo markup, substituted into the "%s" slot of the HTML template below.
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
header_template = """
<html>
<head>
<style>
.img-container {
padding-left: 90px;
padding-right: 90px;
padding-top: 50px;
padding-bottom: 50px;
background-color: #f0f3f9;
}
</style>
</head>
<body>
<span class="img-container"> <!-- Inline parent element -->
%s
</span>
</body>
</html>
"""
header_full = header_template % header_html
st.sidebar.markdown(header_full, unsafe_allow_html=True)

# Long Form QA with ELI5 and Wikipedia
description = """
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
First, a document retriever fetches a set of relevant Wikipedia passages given the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
a pre-processed fixed snapshot of Wikipedia.
"""
st.sidebar.markdown(description, unsafe_allow_html=True)
|
||||
|
||||
# --- Demo options: what to display -------------------------------------------------
action_list = [
    "Answer the question",
    "View the retrieved document only",
    "View the most similar ELI5 question and answer",
    "Show me everything, please!",
]
# Defaults used while the "Demo options" box is unticked: show everything, full passages.
action = 3
show_passages = True
demo_options = st.sidebar.checkbox("Demo options")
if demo_options:
    action_st = st.sidebar.selectbox("", action_list, index=3,)
    action = action_list.index(action_st)
    full_text_choice = "Show full text of passages"
    show_type = st.sidebar.selectbox("", [full_text_choice, "Show passage section titles"], index=0,)
    show_passages = show_type == full_text_choice

# --- Retrieval options: passage source and retriever -------------------------------
wiki_source = "wiki40b"
index_type = "dense"
retrieval_options = st.sidebar.checkbox("Retrieval options")
if retrieval_options:
    retriever_info = """
### Information retriever options

The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between a question and passage embedding
trained using the [ELI5](https://arxiv.org/abs/1907.09190) questions-answer pairs.
The answer is then generated by sequence to sequence model which takes the question and retrieved document as input.
"""
    st.sidebar.markdown(retriever_info)
    wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
    index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])

# --- Generation options: decoding settings -----------------------------------------
# Defaults correspond to beam search; the sampling parameters stay unset.
sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
    generate_info = """
### Answer generation options

The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can use the model for greedy decoding with
**beam** search, or **sample** from the decoder's output probabilities.
"""
    st.sidebar.markdown(generate_info)
    sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
    min_len = st.sidebar.slider(
        "Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
    )
    max_len = st.sidebar.slider(
        "Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
    )
    if sampled == "beam":
        n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
    else:
        # Sampling mode: no beam count, nucleus/temperature sliders instead.
        n_beams = None
        top_p = st.sidebar.slider(
            "Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
        )
        temp = st.sidebar.slider(
            "Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
        )
|
||||
|
||||
# start main text
# The first entry is a sentinel: picking it reveals a free-text input instead.
questions_list = [
    "<MY QUESTION>",
    "How do people make chocolate?",
    "Why do we get a fever when we are sick?",
    "How can different animals perceive different colors?",
    "What is natural language processing?",
    "What's the best way to treat a sunburn?",
    "What exactly are vitamins ?",
    "How does nuclear energy provide electricity?",
    "What's the difference between viruses and bacteria?",
    "Why are flutes classified as woodwinds when most of them are made out of metal ?",
    "Why do people like drinking coffee even though it tastes so bad?",
    "What happens when wine ages? How does it make the wine taste better?",
    "If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
    "How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
    "How does New Zealand have so many large bird predators?",
]
question_s = st.selectbox(
    "What would you like to ask? ---- select <MY QUESTION> to enter a new query", questions_list, index=1,
)
# The text box is only created when the sentinel entry is selected.
question = st.text_input("Enter your question here:", "") if question_s == "<MY QUESTION>" else question_s
|
||||
|
||||
# Everything below runs only on the script pass triggered by the "Show me!" button.
# ``action`` indexes ``action_list``: 0 = answer, 1 = retrieved passages only,
# 2 = nearest ELI5 training example, 3 = everything.
if st.button("Show me!"):
    if action in [0, 1, 3]:
        if index_type == "mixed" and wiki_source != "none":
            _, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
            _, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
            # Interleave dense and sparse hits by rank. Duplicates are detected on
            # everything except the score (index 2): the score comes from a different
            # retriever in each list, so comparing whole tuples would miss the same
            # passage whenever the two scores differ.
            support_list = []
            seen_passages = set()
            # Walk up to the longer list so hits are not dropped when one retriever
            # returns fewer results than the other.
            for rank in range(max(len(support_list_dense), len(support_list_sparse))):
                for hits in (support_list_dense, support_list_sparse):
                    if rank >= len(hits):
                        continue
                    res = tuple(hits[rank])
                    passage_key = res[:2] + res[3:]
                    if passage_key not in seen_passages:
                        seen_passages.add(passage_key)
                        support_list.append(res)
            support_list = support_list[:10]
            support_doc = "<P> " + " <P> ".join([res[-1] for res in support_list])
            # Same "question: ... context: ..." wrapper that make_support() builds, so the
            # generator also receives the question in mixed mode.
            question_doc = "question: {} context: {}".format(question, support_doc)
        else:
            # Also covers "mixed" with source "none": make_support() then returns its
            # placeholder document and an empty passage list.
            question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
    if action in [0, 3]:
        answer, support_list = answer_question(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            min_len=min_len,
            max_len=int(max_len),
            sampling=(sampled == "sampled"),
            n_beams=n_beams,
            top_p=top_p,
            temp=temp,
        )
        st.markdown("### The model generated answer is:")
        st.write(answer)
    if action in [0, 1, 3] and wiki_source != "none":
        st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
        for i, res in enumerate(support_list):
            # res is (article_title, section_title, score, passage_text).
            wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
            sec_titles = res[1].strip()
            if sec_titles == "":
                sections = "[{}]({})".format(res[0], wiki_url)
            else:
                # Several section titles are separated by " & "; link each one to its anchor.
                sec_list = sec_titles.split(" & ")
                sections = " & ".join(
                    ["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
                )
            st.markdown(
                "{0:02d} - **Article**: {1:<18} <br> _Section_: {2}".format(i + 1, res[0], sections),
                unsafe_allow_html=True,
            )
            if show_passages:
                st.write(
                    '> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
                )
    if action in [2, 3]:
        nn_train_list = find_nearest_training(question)
        train_exple = nn_train_list[0]
        st.markdown(
            "--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
        )
        # Always show the first answer; further answers only when their score exceeds 2.
        answers_st = [
            "{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
            for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
            if i == 0 or sc > 2
        ]
        st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))
|
||||
|
||||
|
||||
# Closing note, rendered as the last element of the sidebar.
disclaimer = """
---

**Disclaimer**

*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
"""
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)
|
||||
Reference in New Issue
Block a user