Eli5 examples (#4968)
* add eli5 examples * add dense query script * query_di * merging * merging * add_utils * adds nearest neighbor wikipedia * batch queries * training_retriever * new notebooks * moved retriever traiing script * finished wiki40b * max_len_fix * train_s2s * retriever_batch_checkpointing * cleanup * merge * dim_fix * fix_indexer * fix_wiki40b_snippets * fix_embed_for_r * fp32 index * fix_sparse_q * joint_training * remove obsolete datasets * add_passage_nn_results * add_passage_nn_results * add_batch_nn * add_batch_nn * add_data_scripts * notebook * notebook * notebook * fix_multi_gpu * add_app * full_caching * full_caching * notebook * sparse_done * images * notebook * add_image_gif * with_Gif * add_contr_image * notebook * notebook * notebook * train_functions * notebook * min_retrieval_length * pandas_option * notebook * min_retrieval_length * notebook * notebook * eval_Retriever * notebook * images * notebook * add_example * add_example * notebook * fireworks * notebook * notebook * joe's notebook comments * app_update * notebook * notebook_link * captions * notebook * assing RetriBert model * add RetriBert to Auto * change AutoLMHead to AutoSeq2Seq * notebook downloads from hf models * style_black * style_black * app_update * app_update * fix_app_update * style * style * isort * Delete WikiELI5training.ipynb * Delete evaluate_eli5.py * Delete WikiELI5explore.ipynb * Delete ExploreWikiELI5Support.html * Delete explainlikeimfive.py * Delete wiki_snippets.py * children before parent * children before parent * style_black * style_black_only * isort * isort_new * Update src/transformers/modeling_retribert.py Co-authored-by: Julien Chaumond <chaumond@gmail.com> * typo fixes * app_without_asset * cleanup * Delete ELI5animation.gif * Delete ELI5contrastive.svg * Delete ELI5wiki_index.svg * Delete choco_bis.svg * Delete fireworks.gif * Delete huggingface_logo.jpg * Delete huggingface_logo.svg * Delete Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb * Delete 
eli5_app.py * Delete eli5_utils.py * readme * Update README.md * unused imports * moved_info * default_beam * ftuned model * disclaimer * Update src/transformers/modeling_retribert.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * black * add_doc * names * isort_Examples * isort_Examples * Add doc to index Co-authored-by: Julien Chaumond <chaumond@gmail.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
@@ -111,3 +111,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
    model_doc/reformer
    model_doc/marian
    model_doc/longformer
    model_doc/retribert
39
docs/source/model_doc/retribert.rst
Normal file
@@ -0,0 +1,39 @@
RetriBERT
----------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~

The RetriBERT model was proposed in the blog post
`Explain Anything Like I'm Five: A Model for Open Domain Long Form Question Answering <https://yjernite.github.io/lfqa.html>`__.
RetriBERT is a small model that uses either a single BERT encoder or a pair of BERT encoders with a lower-dimensional projection for dense semantic indexing of text.

Code to train and use the model can be found `here <https://github.com/huggingface/transformers/tree/master/examples/longform-qa>`_.


RetriBertConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RetriBertConfig
    :members:


RetriBertTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RetriBertTokenizer
    :members:


RetriBertTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RetriBertTokenizerFast
    :members:


RetriBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RetriBertModel
    :members:
5
examples/longform-qa/README.md
Normal file
@@ -0,0 +1,5 @@
# Long Form Question Answering

This folder contains the code for the Long Form Question Answering [demo](http://35.226.96.115:8080/) as well as methods to train and use a fully end-to-end Long Form Question Answering system using the [🤗transformers](https://github.com/huggingface/transformers) and [🤗nlp](https://github.com/huggingface/nlp) libraries.

You can use these methods to train your own system by following along with the associated [notebook](https://github.com/huggingface/notebooks/blob/master/longform-qa/Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb) or [blog post](https://yjernite.github.io/lfqa.html).
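
As a rough illustration of the dense retrieval step used by the demo (max-inner-product search between question and passage embeddings), the sketch below ranks a few toy passage vectors against a query vector. The vectors and the `inner_product` and `top_k_passages` helpers are assumptions made up for this example; the actual system computes 128-dimensional embeddings with RetriBERT and searches them with a faiss `IndexFlatIP` index.

```python
# Toy max-inner-product retrieval; all data and helper names here are hypothetical.
from typing import List, Sequence, Tuple


def inner_product(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k_passages(
    query_vec: Sequence[float], passage_vecs: Sequence[Sequence[float]], k: int = 2
) -> List[Tuple[int, float]]:
    """Return (passage index, score) pairs for the k highest inner products."""
    scored = [(idx, inner_product(query_vec, vec)) for idx, vec in enumerate(passage_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Three made-up passage embeddings and one made-up question embedding.
passage_vecs = [
    [0.9, 0.1, 0.0],
    [0.1, 0.8, 0.1],
    [0.7, 0.3, 0.2],
]
query_vec = [1.0, 0.0, 0.0]

hits = top_k_passages(query_vec, passage_vecs, k=2)
print(hits)
```

An exhaustive scan like this is only practical for a handful of vectors; the faiss index plays the same role for the full Wikipedia snippet collection.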
332
examples/longform-qa/eli5_app.py
Normal file
@@ -0,0 +1,332 @@
import numpy as np
import torch

import faiss
import nlp
import streamlit as st
import transformers
from elasticsearch import Elasticsearch
from eli5_utils import (
    embed_questions_for_retrieval,
    make_qa_s2s_model,
    qa_s2s_generate,
    query_es_index,
    query_qa_dense_index,
)
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer


MODEL_TYPE = "bart"
LOAD_DENSE_INDEX = True


@st.cache(allow_output_mutation=True)
def load_models():
    if LOAD_DENSE_INDEX:
        qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased")
        qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to("cuda:0")
        _ = qar_model.eval()
    else:
        qar_tokenizer, qar_model = (None, None)
    if MODEL_TYPE == "bart":
        s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
        s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to("cuda:0")
        save_dict = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
        s2s_model.load_state_dict(save_dict["model"])
        _ = s2s_model.eval()
    else:
        s2s_tokenizer, s2s_model = make_qa_s2s_model(
            model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device="cuda:0"
        )
    return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)


@st.cache(allow_output_mutation=True)
def load_indexes():
    if LOAD_DENSE_INDEX:
        faiss_res = faiss.StandardGpuResources()
        wiki40b_passages = nlp.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
        wiki40b_passage_reps = np.memmap(
            "wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
            dtype="float32",
            mode="r",
            shape=(wiki40b_passages.num_rows, 128),
        )
        wiki40b_index_flat = faiss.IndexFlatIP(128)
        wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
        wiki40b_gpu_index_flat.add(wiki40b_passage_reps)  # TODO fix for larger GPU
    else:
        wiki40b_passages, wiki40b_gpu_index_flat = (None, None)
    es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
    return (wiki40b_passages, wiki40b_gpu_index_flat, es_client)


@st.cache(allow_output_mutation=True)
def load_train_data():
    eli5 = nlp.load_dataset("eli5", name="LFQA_reddit")
    eli5_train = eli5["train_eli5"]
    eli5_train_q_reps = np.memmap(
        "eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128)
    )
    eli5_train_q_index = faiss.IndexFlatIP(128)
    eli5_train_q_index.add(eli5_train_q_reps)
    return (eli5_train, eli5_train_q_index)


passages, gpu_dense_index, es_client = load_indexes()
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
eli5_train, eli5_train_q_index = load_train_data()


def find_nearest_training(question, n_results=10):
    q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
    D, I = eli5_train_q_index.search(q_rep, n_results)
    nn_examples = [eli5_train[int(i)] for i in I[0]]
    return nn_examples


def make_support(question, source="wiki40b", method="dense", n_results=10):
    if source == "none":
        support_doc, hit_lst = (" <P> ".join(["" for _ in range(11)]).strip(), [])
    else:
        if method == "dense":
            support_doc, hit_lst = query_qa_dense_index(
                question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
            )
        else:
            support_doc, hit_lst = query_es_index(
                question, es_client, index_name="english_wiki40b_snippets_100w", n_results=n_results,
            )
    support_list = [
        (res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
    ]
    question_doc = "question: {} context: {}".format(question, support_doc)
    return question_doc, support_list


@st.cache(hash_funcs={torch.Tensor: (lambda _: None), transformers.tokenization_bart.BartTokenizer: (lambda _: None)})
def answer_question(
    question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
    with torch.no_grad():
        answer = qa_s2s_generate(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            num_answers=1,
            num_beams=n_beams,
            min_len=min_len,
            max_len=max_len,
            do_sample=sampling,
            temp=temp,
            top_p=top_p,
            top_k=None,
            max_input_length=1024,
            device="cuda:0",
        )[0]
    # NOTE: support_list is not a parameter; this returns the module-level
    # variable that the caller sets before invoking this function.
    return (answer, support_list)


st.title("Long Form Question Answering with ELI5")

# Start sidebar
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
header_full = """
<html>
<head>
<style>
.img-container {
padding-left: 90px;
padding-right: 90px;
padding-top: 50px;
padding-bottom: 50px;
background-color: #f0f3f9;
}
</style>
</head>
<body>
<span class="img-container"> <!-- Inline parent element -->
%s
</span>
</body>
</html>
""" % (
    header_html,
)
st.sidebar.markdown(
    header_full, unsafe_allow_html=True,
)

# Long Form QA with ELI5 and Wikipedia
description = """
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
First, a document retriever fetches a set of relevant Wikipedia passages given the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
a pre-processed fixed snapshot of Wikipedia.
"""
st.sidebar.markdown(description, unsafe_allow_html=True)

action_list = [
    "Answer the question",
    "View the retrieved document only",
    "View the most similar ELI5 question and answer",
    "Show me everything, please!",
]
demo_options = st.sidebar.checkbox("Demo options")
if demo_options:
    action_st = st.sidebar.selectbox("", action_list, index=3,)
    action = action_list.index(action_st)
    show_type = st.sidebar.selectbox("", ["Show full text of passages", "Show passage section titles"], index=0,)
    show_passages = show_type == "Show full text of passages"
else:
    action = 3
    show_passages = True

retrieval_options = st.sidebar.checkbox("Retrieval options")
if retrieval_options:
    retriever_info = """
### Information retriever options

The **sparse** retriever uses Elasticsearch, while the **dense** retriever uses max-inner-product search between question and passage embeddings
trained using the [ELI5](https://arxiv.org/abs/1907.09190) question-answer pairs.
The answer is then generated by a sequence-to-sequence model, which takes the question and retrieved document as input.
"""
    st.sidebar.markdown(retriever_info)
    wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
    index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])
else:
    wiki_source = "wiki40b"
    index_type = "dense"

sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
    generate_info = """
### Answer generation options

The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can decode with
**beam** search, or **sample** from the decoder's output probabilities.
"""
    st.sidebar.markdown(generate_info)
    sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
    min_len = st.sidebar.slider(
        "Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
    )
    max_len = st.sidebar.slider(
        "Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
    )
    if sampled == "beam":
        n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
    else:
        top_p = st.sidebar.slider(
            "Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
        )
        temp = st.sidebar.slider(
            "Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
        )
        n_beams = None

# start main text
questions_list = [
    "<MY QUESTION>",
    "How do people make chocolate?",
    "Why do we get a fever when we are sick?",
    "How can different animals perceive different colors?",
    "What is natural language processing?",
    "What's the best way to treat a sunburn?",
    "What exactly are vitamins ?",
    "How does nuclear energy provide electricity?",
    "What's the difference between viruses and bacteria?",
    "Why are flutes classified as woodwinds when most of them are made out of metal ?",
    "Why do people like drinking coffee even though it tastes so bad?",
    "What happens when wine ages? How does it make the wine taste better?",
    "If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
    "How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
    "How does New Zealand have so many large bird predators?",
]
question_s = st.selectbox(
    "What would you like to ask? ---- select <MY QUESTION> to enter a new query", questions_list, index=1,
)
if question_s == "<MY QUESTION>":
    question = st.text_input("Enter your question here:", "")
else:
    question = question_s

if st.button("Show me!"):
    if action in [0, 1, 3]:
        if index_type == "mixed":
            _, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
            _, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
            support_list = []
            for res_d, res_s in zip(support_list_dense, support_list_sparse):
                if tuple(res_d) not in support_list:
                    support_list += [tuple(res_d)]
                if tuple(res_s) not in support_list:
                    support_list += [tuple(res_s)]
            support_list = support_list[:10]
            # use the same "question: ... context: ..." input format as make_support,
            # so that the generator also sees the question in mixed mode
            support_doc = "<P> " + " <P> ".join([res[-1] for res in support_list])
            question_doc = "question: {} context: {}".format(question, support_doc)
        else:
            question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
    if action in [0, 3]:
        answer, support_list = answer_question(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            min_len=min_len,
            max_len=int(max_len),
            sampling=(sampled == "sampled"),
            n_beams=n_beams,
            top_p=top_p,
            temp=temp,
        )
        st.markdown("### The model generated answer is:")
        st.write(answer)
    if action in [0, 1, 3] and wiki_source != "none":
        st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
        for i, res in enumerate(support_list):
            wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
            sec_titles = res[1].strip()
            if sec_titles == "":
                sections = "[{}]({})".format(res[0], wiki_url)
            else:
                sec_list = sec_titles.split(" & ")
                sections = " & ".join(
                    ["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
                )
            st.markdown(
                "{0:02d} - **Article**: {1:<18} <br> _Section_: {2}".format(i + 1, res[0], sections),
                unsafe_allow_html=True,
            )
            if show_passages:
                st.write(
                    '> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
                )
    if action in [2, 3]:
        nn_train_list = find_nearest_training(question)
        train_exple = nn_train_list[0]
        st.markdown(
            "--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
        )
        answers_st = [
            "{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
            for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
            if i == 0 or sc > 2
        ]
        st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))


disclaimer = """
---

**Disclaimer**

*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
"""
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)
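
The "mixed" retrieval mode in the app above interleaves the dense and sparse hit lists, skips hits that were already collected, and keeps the first ten results. A minimal standalone sketch of that merge follows; the `interleave_unique` helper and the toy hit tuples are hypothetical and not part of the app.

```python
# Standalone sketch of the "mixed" merge; helper name and data are illustrative assumptions.
from typing import List, Sequence, Tuple

Hit = Tuple[str, str]  # (article title, passage text) stand-in for the app's 4-tuples


def interleave_unique(dense_hits: Sequence[Hit], sparse_hits: Sequence[Hit], limit: int = 10) -> List[Hit]:
    """Alternate between dense and sparse hits, dropping duplicates, then truncate."""
    merged: List[Hit] = []
    # zip stops at the shorter list, as the loop in the app does
    for dense_hit, sparse_hit in zip(dense_hits, sparse_hits):
        for hit in (dense_hit, sparse_hit):
            if hit not in merged:
                merged.append(hit)
    return merged[:limit]


dense_hits = [("Chocolate", "p1"), ("Cocoa bean", "p2"), ("Candy", "p3")]
sparse_hits = [("Cocoa bean", "p2"), ("Chocolate", "p1"), ("Fudge", "p4")]

merged = interleave_unique(dense_hits, sparse_hits, limit=3)
print(merged)
```

Because the merge alternates sources rather than comparing scores, it sidesteps the fact that inner-product scores and BM25 scores are not on a common scale.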
653
examples/longform-qa/eli5_utils.py
Normal file
@@ -0,0 +1,653 @@
import functools
import math
import os  # noqa: F401
from random import choice, randint
from time import time

import numpy as np
import torch
import torch.utils.checkpoint as checkpoint
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm

import faiss  # noqa: F401
import nlp  # noqa: F401
import pandas as pd
from elasticsearch import Elasticsearch  # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk  # noqa: F401
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup


pd.set_option("display.max_colwidth", None)


###############
# Sparse index
###############
def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_kilt_snippets_100w"):
    index_config = {
        "settings": {
            "number_of_shards": 1,
            "analysis": {"analyzer": {"stop_standard": {"type": "standard", "stopwords": "_english_"}}},
        },
        "mappings": {
            "properties": {
                "article_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
                "section_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
                "passage_text": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
            }
        },
    }
    es_client.indices.create(index=index_name, body=index_config)
    number_of_docs = passages_dset.num_rows
    progress = tqdm(unit="docs", total=number_of_docs)
    successes = 0

    def passage_generator():
        for passage in passages_dset:
            yield passage

    # create the ES index
    for ok, action in streaming_bulk(client=es_client, index=index_name, actions=passage_generator(),):
        progress.update(1)
        successes += ok
    print("Indexed %d documents" % (successes,))


def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_100w", n_results=10, min_length=20):
    q = question.lower()
    banned = ["how", "why", "what", "where", "which", "do", "does", "is", "?", "eli5", "eli5:"]
    q = " ".join([w for w in q.split() if w not in banned])
    response = es_client.search(
        index=index_name,
        body={
            "query": {
                "multi_match": {
                    "query": q,
                    "fields": ["article_title", "section_title", "passage_text^2"],
                    "type": "cross_fields",
                }
            },
            "size": 2 * n_results,
        },
    )
    hits = response["hits"]["hits"]
    support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
    res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
    for r, hit in zip(res_list, hits):
        r["passage_id"] = hit["_id"]
        r["score"] = hit["_score"]
        r["passage_text"] = hit["_source"]["passage_text"]
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    return support_doc, res_list


###############
# ELI5 retriever training
###############
class ELI5DatasetQARetriver(Dataset):
    def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):
        self.data = examples_array
        self.answer_thres = extra_answer_threshold
        self.min_length = min_answer_length
        self.training = training
        self.n_samples = self.data.num_rows if n_samples is None else n_samples

    def __len__(self):
        return self.n_samples

    def make_example(self, idx):
        example = self.data[idx]
        question = example["title"]
        if self.training:
            answers = [a for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"]))]
            answer_tab = choice(answers).split(" ")
            start_idx = randint(0, max(0, len(answer_tab) - self.min_length))
            answer_span = " ".join(answer_tab[start_idx:])
        else:
            answer_span = example["answers"]["text"][0]
        return (question, answer_span)

    def __getitem__(self, idx):
        return self.make_example(idx % self.data.num_rows)


class RetrievalQAEmbedder(torch.nn.Module):
    def __init__(self, sent_encoder, dim):
        super(RetrievalQAEmbedder, self).__init__()
        self.sent_encoder = sent_encoder
        self.output_dim = 128
        self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
        self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
        self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")

    def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
        # reproduces BERT forward pass with checkpointing
        if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
            return self.sent_encoder(input_ids, attention_mask=attention_mask)[1]
        else:
            # prepare implicit variables
            device = input_ids.device
            input_shape = input_ids.size()
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            head_mask = [None] * self.sent_encoder.config.num_hidden_layers
            extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
                attention_mask, input_shape, device
            )

            # define function for checkpointing
            def partial_encode(*inputs):
                encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
                sequence_output = encoder_outputs[0]
                pooled_output = self.sent_encoder.pooler(sequence_output)
                return pooled_output

            # run embedding layer on everything at once
            embedding_output = self.sent_encoder.embeddings(
                input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
            )
            # run encoding and pooling on one mini-batch at a time
            pooled_output_list = []
            for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
                b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
                pooled_output_list.append(pooled_output)
            return torch.cat(pooled_output_list, dim=0)

    def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):
        q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)
        return self.project_q(q_reps)

    def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):
        a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)
        return self.project_a(a_reps)

    def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):
        device = q_ids.device
        q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)
        a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)
        compare_scores = torch.mm(q_reps, a_reps.t())
        loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
        loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
        loss = (loss_qa + loss_aq) / 2
        return loss


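
The `forward` method above computes an in-batch contrastive loss: every question is scored against every answer in the batch, and cross-entropy pushes the matching pair (the diagonal of the score matrix) above the others, in both the question-to-answer and answer-to-question directions. The following is a dependency-free sketch of the same computation on plain Python lists; the function names and toy vectors are assumptions for illustration and are not part of `eli5_utils.py`.

```python
# Pure-Python sketch of the symmetric in-batch loss; names and data are hypothetical.
import math
from typing import List, Sequence


def cross_entropy_diagonal(scores: Sequence[Sequence[float]]) -> float:
    """Mean cross-entropy where the correct class for row i is column i."""
    total = 0.0
    for i, row in enumerate(scores):
        max_score = max(row)  # subtract the max for numerical stability
        log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in row))
        total += log_norm - row[i]
    return total / len(scores)


def in_batch_contrastive_loss(q_reps: Sequence[Sequence[float]], a_reps: Sequence[Sequence[float]]) -> float:
    """Average of the question->answer and answer->question cross-entropy losses."""
    scores: List[List[float]] = [[sum(q * a for q, a in zip(qv, av)) for av in a_reps] for qv in q_reps]
    scores_t = [list(col) for col in zip(*scores)]
    return (cross_entropy_diagonal(scores) + cross_entropy_diagonal(scores_t)) / 2


# Toy batch of two question/answer embedding pairs.
questions = [[2.0, 0.0], [0.0, 2.0]]
answers_aligned = [[2.0, 0.0], [0.0, 2.0]]
answers_swapped = [[0.0, 2.0], [2.0, 0.0]]

loss_aligned = in_batch_contrastive_loss(questions, answers_aligned)
loss_swapped = in_batch_contrastive_loss(questions, answers_swapped)
print(loss_aligned, loss_swapped)
```

With this objective the other answers in a batch act as negatives for each question, so no separate negative sampling step is needed.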
def make_qa_retriever_model(model_name="google/bert_uncased_L-8_H-512_A-8", from_file=None, device="cuda:0"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bert_model = AutoModel.from_pretrained(model_name).to(device)
    # run bert_model on a dummy batch to get output dimension
    d_ids = torch.LongTensor(
        [[bert_model.config.bos_token_id if bert_model.config.bos_token_id is not None else 1]]
    ).to(device)
    d_mask = torch.LongTensor([[1]]).to(device)
    sent_dim = bert_model(d_ids, attention_mask=d_mask)[1].shape[-1]
    qa_embedder = RetrievalQAEmbedder(bert_model, sent_dim).to(device)
    if from_file is not None:
        param_dict = torch.load(from_file)  # has model weights, optimizer, and scheduler states
        qa_embedder.load_state_dict(param_dict["model"])
    return tokenizer, qa_embedder


def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"):
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=max_len, pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    return (q_ids, q_mask, a_ids, a_mask)


def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0):
    model.train()
    # make iterator
    train_sampler = RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch in enumerate(epoch_iterator):
        q_ids, q_mask, a_ids, a_mask = batch
        pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
        loss = pre_loss.sum()
        # optimizer
        loss.backward()
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


|
def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=0):
    model.train()
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    # make iterator
    train_samplers = [RandomSampler(dataset) for dataset in dataset_list]
    data_loaders = [
        DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
        for dataset, train_sampler in zip(dataset_list, train_samplers)
    ]
    iterators = [iter(dloader) for dloader in data_loaders]
    joint_iter = zip(*iterators)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batches in enumerate(joint_iter):
        for batch in batches:
            q_ids, q_mask, a_ids, a_mask = batch
            loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
            # optimizer
            loss.backward()
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            # some printing within the epoch
            loc_loss += loss.item()
            loc_steps += 1
        if step % args.print_freq == 0:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


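One property of the joint loop in train_qa_retriever_joint_epoch is worth spelling out: zipping the per-dataset iterators means an epoch ends as soon as the shortest dataset runs out of batches. The sketch below uses plain lists as hypothetical stand-ins for the DataLoaders; it only illustrates the iteration pattern, not the training itself.

```python
# Toy stand-ins for two DataLoaders of different lengths (hypothetical data).
loader_a = ["a0", "a1", "a2"]
loader_b = ["b0", "b1"]

# zip() stops as soon as the shortest iterator is exhausted,
# so the third batch of loader_a is never visited.
joint_steps = list(zip(iter(loader_a), iter(loader_b)))

for step, batches in enumerate(joint_steps):
    for batch in batches:
        print(step, batch)
```

If every example of the larger dataset should be seen, the datasets would need to be of similar size or the smaller one would have to be re-iterated; that choice is left open here.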
def evaluate_qa_retriever(model, dataset, tokenizer, args):
    model.eval()
    # make iterator
    eval_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    tot_loss = 0.0
    with torch.no_grad():
        for step, batch in enumerate(epoch_iterator):
            q_ids, q_mask, a_ids, a_mask = batch
            loss = model(q_ids, q_mask, a_ids, a_mask)
            tot_loss += loss.item()
        return tot_loss / (step + 1)


def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args):
    qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)
    qar_scheduler = get_linear_schedule_with_warmup(
        qar_optimizer,
        num_warmup_steps=100,
        num_training_steps=(qar_args.num_epochs + 1) * math.ceil(len(qar_train_dset) / qar_args.batch_size),
    )
    for e in range(qar_args.num_epochs):
        train_qa_retriever_epoch(qar_model, qar_train_dset, qar_tokenizer, qar_optimizer, qar_scheduler, qar_args, e)
        m_save_dict = {
            "model": qar_model.state_dict(),
            "optimizer": qar_optimizer.state_dict(),
            "scheduler": qar_scheduler.state_dict(),
        }
        print("Saving model {}".format(qar_args.model_save_name))
        torch.save(m_save_dict, "{}_{}.pth".format(qar_args.model_save_name, e))
        eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
        print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))


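The scheduler in train_qa_retriever is built with num_training_steps set to (num_epochs + 1) epochs' worth of steps, so the learning rate does not decay all the way to zero by the end of training. The sketch below reproduces the usual linear warmup and decay multiplier in pure Python to make that visible. The helper name linear_warmup_multiplier and the dataset sizes are hypothetical, and the formula is assumed to match what get_linear_schedule_with_warmup computes.

```python
import math


def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: linear ramp up, then linear decay to zero."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


# Hypothetical sizes, chosen only to give round numbers.
num_epochs, n_examples, batch_size, warmup = 2, 1000, 10, 100
steps_per_epoch = math.ceil(n_examples / batch_size)
total_steps = (num_epochs + 1) * steps_per_epoch  # same formula as in train_qa_retriever
last_step = num_epochs * steps_per_epoch          # last step actually reached in training

for step in (0, 50, warmup, last_step, total_steps):
    print(step, linear_warmup_multiplier(step, warmup, total_steps))
```

With these numbers the multiplier at the last training step is still 0.5, which is the effect of the extra epoch in the denominator.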
###############
# ELI5 seq2seq model training
###############
class ELI5DatasetS2S(Dataset):
    def __init__(
        self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True
    ):
        self.training = training
        self.data = examples_array
        self.make_doc_function = make_doc_fun
        self.document_cache = {} if document_cache is None else document_cache
        assert not (make_doc_fun is None and document_cache is None)
        # make index of specific question-answer pairs from multi-answers
        if self.training:
            self.qa_id_list = [
                (i, j)
                for i, qa in enumerate(self.data)
                for j, (a, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"]))
                if j == 0 or sc >= extra_answer_threshold
            ]
        else:
            self.qa_id_list = [(i, 0) for i in range(self.data.num_rows)]

    def __len__(self):
        return len(self.qa_id_list)

    def make_example(self, idx):
        i, j = self.qa_id_list[idx]
        example = self.data[i]
        question = example["title"] + " " + example["selftext"]
        answer = example["answers"]["text"][j]
        q_id = example["q_id"]
        # build the support document only on a cache miss; passing the call as the
        # default of dict.get() would run the retrieval on every access
        if self.make_doc_function is not None and q_id not in self.document_cache:
            self.document_cache[q_id] = self.make_doc_function(example["title"])
        document = self.document_cache[q_id]
        in_st = "question: {} context: {}".format(
            question.lower().replace(" --t--", "").strip(), document.lower().strip(),
        )
        out_st = answer
        return (in_st, out_st)

    def __getitem__(self, idx):
        return self.make_example(idx)


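The training index built in ELI5DatasetS2S.__init__ keeps the first answer of every question and adds any further answer whose score reaches extra_answer_threshold. A minimal sketch of that selection on hypothetical records (same field layout as above, made-up texts and scores):

```python
# Hypothetical examples shaped like the ELI5 records used by ELI5DatasetS2S.
data = [
    {"answers": {"text": ["a", "b", "c"], "score": [1, 5, 2]}},
    {"answers": {"text": ["d"], "score": [0]}},
]
extra_answer_threshold = 3

# Same comprehension as in __init__: answer 0 is always kept,
# later answers only when their score reaches the threshold.
qa_id_list = [
    (i, j)
    for i, qa in enumerate(data)
    for j, (a, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"]))
    if j == 0 or sc >= extra_answer_threshold
]
print(qa_id_list)
```

Each (i, j) pair then becomes one training example, so a question with several well-scored answers contributes several examples.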
def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if from_file is not None:
        param_dict = torch.load(from_file)  # has model weights, optimizer, and scheduler states
        model.load_state_dict(param_dict["model"])
    return tokenizer, model


def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    lm_labels = a_ids[:, 1:].contiguous().clone()
    lm_labels[a_mask[:, 1:].contiguous() == 0] = -100
    model_inputs = {
        "input_ids": q_ids,
        "attention_mask": q_mask,
        "decoder_input_ids": a_ids[:, :-1].contiguous(),
        "lm_labels": lm_labels,
    }
    return model_inputs


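make_qa_s2s_batch prepares teacher-forcing inputs: the decoder sees the answer tokens shifted by one position, and padded label positions are set to -100, the default ignore_index of PyTorch's cross-entropy loss. The sketch below shows the same shift with plain lists instead of tensors; the token ids and PAD_ID are hypothetical.

```python
PAD_ID = 1            # hypothetical padding token id
IGNORE_INDEX = -100   # label value ignored by the cross-entropy loss

# One padded answer: <s> tok tok </s> <pad> <pad>
a_ids = [[0, 11, 12, 2, PAD_ID, PAD_ID]]
a_mask = [[1, 1, 1, 1, 0, 0]]

# Decoder input: everything except the last position.
decoder_input_ids = [row[:-1] for row in a_ids]

# Labels: everything except the first position, with padding masked out.
lm_labels = [
    [tok if m == 1 else IGNORE_INDEX for tok, m in zip(row[1:], mask[1:])]
    for row, mask in zip(a_ids, a_mask)
]

print(decoder_input_ids)
print(lm_labels)
```

At position t the decoder input is token t and the label is token t + 1, so the model is trained to predict the next answer token at every non-padded step.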
def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
    model.train()
    # make iterator
    if curriculum:
        train_sampler = SequentialSampler(dataset)
    else:
        train_sampler = RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch_inputs in enumerate(epoch_iterator):
        pre_loss = model(**batch_inputs)[0]
        loss = pre_loss.sum() / pre_loss.shape[0]
        loss.backward()
        # optimizer
        if step % args.backward_freq == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


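train_qa_s2s_epoch accumulates gradients and only steps the optimizer when step % args.backward_freq == 0. The sketch below counts, for a hypothetical backward_freq of 4 and 10 batches, how many batches feed each optimizer update. The counter stands in for loss.backward(), and no model is involved.

```python
backward_freq = 4   # hypothetical value of args.backward_freq
n_steps = 10        # hypothetical number of batches in the epoch

accumulated = 0
batches_per_update = []
for step in range(n_steps):
    accumulated += 1                      # stands in for loss.backward()
    if step % backward_freq == 0:
        batches_per_update.append(accumulated)  # optimizer.step(); scheduler.step()
        accumulated = 0                         # model.zero_grad()

print(batches_per_update, accumulated)
```

Under this condition the first update uses a single batch (step 0), and gradients from trailing batches stay accumulated when the epoch ends. The loss in the function above is also not rescaled by backward_freq, so the effective gradient scale grows with it.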
def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
    model.eval()
    # make iterator
    eval_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss over the whole evaluation set
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    with torch.no_grad():
        for step, batch_inputs in enumerate(epoch_iterator):
            pre_loss = model(**batch_inputs)[0]
            loss = pre_loss.sum() / pre_loss.shape[0]
            loc_loss += loss.item()
            loc_steps += 1
            if step % args.print_freq == 0:
                print(
                    "{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                        step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                    )
                )
    print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,))


def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
    s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
    s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=400,
        num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
    )
    for e in range(s2s_args.num_epochs):
        train_qa_s2s_epoch(
            qa_s2s_model,
            s2s_train_dset,
            qa_s2s_tokenizer,
            s2s_optimizer,
            s2s_scheduler,
            s2s_args,
            e,
            curriculum=(e == 0),
        )
        m_save_dict = {
            "model": qa_s2s_model.state_dict(),
            "optimizer": s2s_optimizer.state_dict(),
            "scheduler": s2s_scheduler.state_dict(),
        }
        print("Saving model {}".format(s2s_args.model_save_name))
        eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
        torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))


# generate answer from input "question: ... context: <p> ..."
def qa_s2s_generate(
    question_doc,
    qa_s2s_model,
    qa_s2s_tokenizer,
    num_answers=1,
    num_beams=None,
    min_len=64,
    max_len=256,
    do_sample=False,
    temp=1.0,
    top_p=None,
    top_k=None,
    max_input_length=512,
    device="cuda:0",
):
    model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,)
    n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
    generated_ids = qa_s2s_model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        min_length=min_len,
        max_length=max_len,
        do_sample=do_sample,
        early_stopping=True,
        num_beams=1 if do_sample else n_beams,
        temperature=temp,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=qa_s2s_tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        num_return_sequences=num_answers,
        decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,
    )
    return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]


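The beam count passed to generate in qa_s2s_generate follows a small rule: use at least as many beams as requested answers, and fall back to a single beam when sampling. The helper below isolates that rule so it can be checked on its own; choose_num_beams is a hypothetical name, not a function from the file.

```python
def choose_num_beams(num_answers=1, num_beams=None, do_sample=False):
    """Mirror of the beam-count logic in qa_s2s_generate (hypothetical helper)."""
    n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
    return 1 if do_sample else n_beams


print(choose_num_beams())                              # default: one answer, one beam
print(choose_num_beams(num_answers=3))                 # beams raised to match answers
print(choose_num_beams(num_answers=3, num_beams=8))    # explicit, larger beam count kept
print(choose_num_beams(num_answers=4, do_sample=True)) # sampling ignores the beam count
```

Raising the beam count to num_answers matters because beam search cannot return more sequences than it has beams.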
###############
# ELI5-trained retrieval model usage
###############
def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"):
    a_toks = tokenizer.batch_encode_plus(passages, max_length=max_length, pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    with torch.no_grad():
        a_reps = qa_embedder.embed_answers(a_ids, a_mask).cpu().type(torch.float)
    return a_reps.numpy()


def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"):
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=128, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    with torch.no_grad():
        q_reps = qa_embedder.embed_questions(q_ids, q_mask).cpu().type(torch.float)
    return q_reps.numpy()


def make_qa_dense_index(
    qa_embedder,
    tokenizer,
    passages_dset,
    batch_size=512,
    max_length=128,
    index_name="kilt_passages_reps.dat",
    dtype="float32",
    device="cuda:0",
):
    st_time = time()
    fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
    n_batches = math.ceil(passages_dset.num_rows / batch_size)
    for i in range(n_batches):
        passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
        reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
        fp[i * batch_size : (i + 1) * batch_size] = reps
        if i % 50 == 0:
            print(i, time() - st_time)


def evaluate_retriever(qa_list, retriever_func, scoring_func, n_ret=10, verbose=False):
    total_retriever_time = 0.0
    total_retriever_score = 0.0
    st_time = time()
    for i, (question, answer) in enumerate(qa_list):
        r_time = time()
        retrieved_passages = retriever_func(question, n_ret)
        total_retriever_time += time() - r_time
        total_retriever_score += scoring_func(retrieved_passages, answer)
        if verbose and ((i + 1) % 500 == 0 or i <= 1):
            print(
                "{:03d}: S-{:.4f} T-{:.4f} | {:.2f}".format(
                    i + 1, total_retriever_score / (i + 1), total_retriever_time / (i + 1), time() - st_time
                )
            )
    return {"idf_recall": total_retriever_score / (i + 1), "retrieval_time": total_retriever_time / (i + 1)}


# build a support document for the question out of Wikipedia snippets
def query_qa_dense_index(
    question, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20, device="cuda:0"
):
    q_rep = embed_questions_for_retrieval([question], tokenizer, qa_embedder, device=device)
    D, I = wiki_index.search(q_rep, 2 * n_results)
    res_passages = [wiki_passages[int(i)] for i in I[0]]
    support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
    # attach scores before filtering, so that each passage keeps its own score
    for r, sc in zip(res_list, D[0]):
        r["score"] = float(sc)
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    return support_doc, res_list


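When retrieved passages are filtered by length, scores have to be paired with passages before the filter runs; zipping the score array with an already filtered list shifts every score onto the wrong passage. The sketch below shows the pairing order on hypothetical passages and scores, with plain lists standing in for the index output.

```python
# Hypothetical retrieval output: passages in rank order and their scores.
passages = [
    {"passage_text": "too short"},
    {"passage_text": "this passage is long enough to keep"},
    {"passage_text": "another passage that is long enough"},
]
scores = [0.9, 0.8, 0.7]
min_length = 3

# Pair each passage with its own score first ...
scored = [dict(p, score=float(sc)) for p, sc in zip(passages, scores)]
# ... then drop passages that are too short.
kept = [r for r in scored if len(r["passage_text"].split()) > min_length]

print([(r["passage_text"], r["score"]) for r in kept])
```

Had the filter run first, the two kept passages would have received 0.9 and 0.8, which are the scores of the first two retrieved passages rather than their own.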
def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
    q_rep = embed_questions_for_retrieval(questions, tokenizer, qa_embedder)
    D, I = wiki_index.search(q_rep, n_results)
    res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
    support_doc_lst = [
        "<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
    ]
    all_res_lists = []
    for (res_passages, dl) in zip(res_passages_lst, D):
        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
        for r, sc in zip(res_list, dl):
            r["score"] = float(sc)
        all_res_lists += [res_list[:]]
    return support_doc_lst, all_res_lists


# find nearest neighbors of an answer or declarative text in Wikipedia snippets
def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20):
    a_rep = embed_passages_for_retrieval([passage], tokenizer, qa_embedder)
    D, I = wiki_index.search(a_rep, 2 * n_results)
    res_passages = [wiki_passages[int(i)] for i in I[0]]
    support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
    # attach ids and scores before filtering, so that each passage keeps its own values
    for r, sc, i in zip(res_list, D[0], I[0]):
        r["passage_id"] = int(i)
        r["score"] = float(sc)
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    return support_doc, res_list


def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
    a_reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder)
    D, I = wiki_index.search(a_reps, n_results)
    res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
    support_doc_lst = [
        "<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
    ]
    all_res_lists = []
    for (res_passages, dl, il) in zip(res_passages_lst, D, I):
        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
        for r, sc, i in zip(res_list, dl, il):
            r["passage_id"] = int(i)
            r["score"] = float(sc)
        all_res_lists += [res_list[:]]
    return support_doc_lst, all_res_lists
@@ -36,6 +36,7 @@ from .configuration_marian import MarianConfig
 from .configuration_mmbt import MMBTConfig
 from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
 from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
+from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
 from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
 from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
 from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
@@ -130,6 +131,7 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
 from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
 from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
 from .tokenization_reformer import ReformerTokenizer
+from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
 from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
 from .tokenization_t5 import T5Tokenizer
 from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
@@ -356,6 +358,12 @@ if is_torch_available():
         LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

+    from .modeling_retribert import (
+        RetriBertPreTrainedModel,
+        RetriBertModel,
+        RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+    )
+
     # Optimization
     from .optimization import (
         AdamW,
@@ -32,6 +32,7 @@ from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
 from .configuration_marian import MarianConfig
 from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
 from .configuration_reformer import ReformerConfig
+from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
 from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
 from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
 from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
@@ -64,6 +65,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
         FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
         LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
     ]
     for key, value, in pretrained_map.items()
 )
@@ -71,6 +73,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(

 CONFIG_MAPPING = OrderedDict(
     [
+        ("retribert", RetriBertConfig,),
         ("t5", T5Config,),
         ("distilbert", DistilBertConfig,),
         ("albert", AlbertConfig,),
@@ -28,6 +28,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     "facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
     "facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
     "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
+    "yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
 }


112
src/transformers/configuration_retribert.py
Normal file
@@ -0,0 +1,112 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RetriBERT model configuration """


import logging

from .configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)

# TODO: upload to AWS
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "retribert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
}


class RetriBertConfig(PretrainedConfig):
    r"""
        This is the configuration class to store the configuration of a :class:`~transformers.RetriBertModel`.
        It is used to instantiate a RetriBertModel according to the specified arguments, defining the model
        architecture.

        Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
        to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
        for more information.


        Args:
            vocab_size (:obj:`int`, optional, defaults to 30522):
                Vocabulary size of the BERT model. Defines the different tokens that
                can be represented by the `input_ids` passed to the forward method of :class:`~transformers.BertModel`.
            hidden_size (:obj:`int`, optional, defaults to 768):
                Dimensionality of the encoder layers and the pooler layer.
            num_hidden_layers (:obj:`int`, optional, defaults to 8):
                Number of hidden layers in the Transformer encoder.
            num_attention_heads (:obj:`int`, optional, defaults to 12):
                Number of attention heads for each attention layer in the Transformer encoder.
            intermediate_size (:obj:`int`, optional, defaults to 3072):
                Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
            hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
                The non-linear activation function (function or string) in the encoder and pooler.
                If string, "gelu", "relu", "swish" and "gelu_new" are supported.
            hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
                The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
                The dropout ratio for the attention probabilities.
            max_position_embeddings (:obj:`int`, optional, defaults to 512):
                The maximum sequence length that this model might ever be used with.
                Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
            type_vocab_size (:obj:`int`, optional, defaults to 2):
                The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
            initializer_range (:obj:`float`, optional, defaults to 0.02):
                The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
            layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
                The epsilon used by the layer normalization layers.
            share_encoders (:obj:`bool`, optional, defaults to True):
                Whether to use the same Bert-type encoder for the queries and the documents.
            projection_dim (:obj:`int`, optional, defaults to 128):
                Final dimension of the query and document representations after projection.

    """
    model_type = "retribert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=8,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        share_encoders=True,
        projection_dim=128,
        pad_token_id=0,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.share_encoders = share_encoders
        self.projection_dim = projection_dim
@@ -34,6 +34,7 @@ from .configuration_auto import (
     LongformerConfig,
     OpenAIGPTConfig,
     ReformerConfig,
+    RetriBertConfig,
     RobertaConfig,
     T5Config,
     TransfoXLConfig,
@@ -111,6 +112,7 @@ from .modeling_longformer import (
 from .modeling_marian import MarianMTModel
 from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
 from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
+from .modeling_retribert import RetriBertModel
 from .modeling_roberta import (
     RobertaForMaskedLM,
     RobertaForMultipleChoice,
@@ -151,6 +153,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
MODEL_MAPPING = OrderedDict(
|
MODEL_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
|
(RetriBertConfig, RetriBertModel),
|
||||||
(T5Config, T5Model),
|
(T5Config, T5Model),
|
||||||
(DistilBertConfig, DistilBertModel),
|
(DistilBertConfig, DistilBertModel),
|
||||||
(AlbertConfig, AlbertModel),
|
(AlbertConfig, AlbertModel),
|
||||||
@@ -174,6 +177,7 @@ MODEL_MAPPING = OrderedDict(
|
|||||||
|
|
||||||
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
|
(RetriBertConfig, RetriBertModel),
|
||||||
(T5Config, T5ForConditionalGeneration),
|
(T5Config, T5ForConditionalGeneration),
|
||||||
(DistilBertConfig, DistilBertForMaskedLM),
|
(DistilBertConfig, DistilBertForMaskedLM),
|
||||||
(AlbertConfig, AlbertForPreTraining),
|
(AlbertConfig, AlbertForPreTraining),
|
||||||
|
|||||||
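Both hunks above place `RetriBertConfig` at the top of an `OrderedDict`. The sketch below shows why position in such a mapping can matter. It assumes the auto classes resolve a config by walking the mapping in order and returning the first `isinstance` match. All class and function names are hypothetical stand-ins, not the library's own.

```python
from collections import OrderedDict


# Hypothetical stand-ins for config and model classes (not the transformers API).
class ParentConfig:
    pass


class ChildConfig(ParentConfig):
    pass


class ParentModel:
    def __init__(self, config):
        self.config = config


class ChildModel(ParentModel):
    pass


# Subclass entries come first so that they are matched before their parent.
MAPPING = OrderedDict([(ChildConfig, ChildModel), (ParentConfig, ParentModel)])


def model_from_config(config, mapping=MAPPING):
    """Return a model built from the first mapping entry whose config class matches."""
    for config_class, model_class in mapping.items():
        if isinstance(config, config_class):
            return model_class(config)
    raise ValueError(f"Unrecognized configuration class {type(config).__name__}")


# With the parent listed first, a ChildConfig would be captured by the parent entry.
WRONG_ORDER = OrderedDict([(ParentConfig, ParentModel), (ChildConfig, ChildModel)])
```

Under that assumption, an entry for a subclass only takes effect if it appears before the entry for its parent class.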
185
src/transformers/modeling_retribert.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
RetriBERT model
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
|
||||||
|
from .configuration_retribert import RetriBertConfig
|
||||||
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_bert import BertLayerNorm, BertModel
|
||||||
|
from .modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
|
"yjernite/retribert-base-uncased",
|
||||||
|
# See all RetriBert models at https://huggingface.co/models?filter=retribert
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
|
||||||
|
class RetriBertPreTrainedModel(PreTrainedModel):
|
||||||
|
""" An abstract class to handle weights initialization and
|
||||||
|
a simple interface for downloading and loading pretrained models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = RetriBertConfig
|
||||||
|
load_tf_weights = None
|
||||||
|
base_model_prefix = "retribert"
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
""" Initialize the weights """
|
||||||
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
elif isinstance(module, BertLayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
RETRIBERT_START_DOCSTRING = r"""
|
||||||
|
|
||||||
|
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
||||||
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
|
||||||
|
usage and behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config (:class:`~transformers.RetriBertConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""BERT-based model to embed queries or documents for document retrieval. """, RETRIBERT_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class RetriBertModel(RetriBertPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.projection_dim = config.projection_dim
|
||||||
|
|
||||||
|
self.bert_query = BertModel(config)
|
||||||
|
self.bert_doc = None if config.share_encoders else BertModel(config)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.project_query = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
|
||||||
|
self.project_doc = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
|
||||||
|
|
||||||
|
self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def embed_sentences_checkpointed(
|
||||||
|
self, input_ids, attention_mask, sent_encoder, checkpoint_batch_size=-1,
|
||||||
|
):
|
||||||
|
# reproduces BERT forward pass with checkpointing
|
||||||
|
if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
|
||||||
|
return sent_encoder(input_ids, attention_mask=attention_mask)[1]
|
||||||
|
else:
|
||||||
|
# prepare implicit variables
|
||||||
|
device = input_ids.device
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
head_mask = [None] * sent_encoder.config.num_hidden_layers
|
||||||
|
extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
|
||||||
|
attention_mask, input_shape, device
|
||||||
|
)
|
||||||
|
|
||||||
|
# define function for checkpointing
|
||||||
|
def partial_encode(*inputs):
|
||||||
|
encoder_outputs = sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
pooled_output = sent_encoder.pooler(sequence_output)
|
||||||
|
return pooled_output
|
||||||
|
|
||||||
|
# run embedding layer on everything at once
|
||||||
|
embedding_output = sent_encoder.embeddings(
|
||||||
|
input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
|
||||||
|
)
|
||||||
|
# run encoding and pooling on one mini-batch at a time
|
||||||
|
pooled_output_list = []
|
||||||
|
for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
|
||||||
|
b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
|
||||||
|
b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
|
||||||
|
pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
|
||||||
|
pooled_output_list.append(pooled_output)
|
||||||
|
return torch.cat(pooled_output_list, dim=0)
|
||||||
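The loop above walks the batch in slices of `checkpoint_batch_size`. The sketch below isolates that slicing arithmetic on a plain list, with no PyTorch and no gradient checkpointing. `chunk_bounds` is a hypothetical helper and is not part of this commit.

```python
import math


def chunk_bounds(n_items, chunk_size):
    """Return (start, stop) pairs that cover range(n_items) in chunks of chunk_size."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [
        (b * chunk_size, min((b + 1) * chunk_size, n_items))
        for b in range(math.ceil(n_items / chunk_size))
    ]


batch = list(range(10))
pieces = [batch[start:stop] for start, stop in chunk_bounds(len(batch), 4)]
# The last piece is shorter than the others; concatenating all pieces restores the batch.
restored = [item for piece in pieces for item in piece]
```

The original code relies on slicing to clip the final chunk. The explicit `min` here only makes that bound visible.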
|
|
||||||
|
def embed_questions(
|
||||||
|
self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
|
||||||
|
):
|
||||||
|
q_reps = self.embed_sentences_checkpointed(input_ids, attention_mask, self.bert_query, checkpoint_batch_size,)
|
||||||
|
return self.project_query(q_reps)
|
||||||
|
|
||||||
|
def embed_answers(
|
||||||
|
self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
|
||||||
|
):
|
||||||
|
a_reps = self.embed_sentences_checkpointed(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
self.bert_query if self.bert_doc is None else self.bert_doc,
|
||||||
|
checkpoint_batch_size,
|
||||||
|
)
|
||||||
|
return self.project_doc(a_reps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids_query (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary for the queries in a batch.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`transformers.RetriBertTokenizer`.
|
||||||
|
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||||
|
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask_query (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Mask to avoid performing attention on queries padding token indices.
|
||||||
|
Mask values selected in ``[0, 1]``:
|
||||||
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
input_ids_doc (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary for the documents in a batch.
|
||||||
|
attention_mask_doc (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Mask to avoid performing attention on documents padding token indices.
|
||||||
|
|
||||||
|
checkpoint_batch_size (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||||
|
If greater than 0, uses gradient checkpointing to only compute sequence representation on checkpoint_batch_size examples at a time
|
||||||
|
on the GPU. All query representations are still compared to all document representations in the batch.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`torch.FloatTensor` the bi-directional cross-entropy loss obtained while trying to match each query to its corresponding document
|
||||||
|
and each document to its corresponding query in the batch
|
||||||
|
"""
|
||||||
|
device = input_ids_query.device
|
||||||
|
q_reps = self.embed_questions(input_ids_query, attention_mask_query, checkpoint_batch_size)
|
||||||
|
a_reps = self.embed_answers(input_ids_doc, attention_mask_doc, checkpoint_batch_size)
|
||||||
|
compare_scores = torch.mm(q_reps, a_reps.t())
|
||||||
|
loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
|
||||||
|
loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
|
||||||
|
loss = (loss_qa + loss_aq) / 2
|
||||||
|
return loss
|
||||||
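`forward` scores every query against every document in the batch and treats the diagonal as the correct match in both directions. The pure-Python sketch below reproduces that computation, assuming mean-reduced cross-entropy over raw dot-product scores as in the code above. Lists stand in for tensors, and the helper names are hypothetical.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def cross_entropy_mean(scores, targets):
    """Mean cross-entropy of each row of `scores` against its target column index."""
    total = 0.0
    for row, target in zip(scores, targets):
        row_max = max(row)  # subtract the max for numerical stability
        log_z = row_max + math.log(sum(math.exp(s - row_max) for s in row))
        total += log_z - row[target]
    return total / len(scores)


def bidirectional_loss(q_reps, a_reps):
    """Average of query-to-document and document-to-query cross-entropy losses."""
    scores = [[dot(q, a) for a in a_reps] for q in q_reps]
    scores_t = [list(col) for col in zip(*scores)]
    targets = list(range(len(scores)))  # query i is paired with document i
    loss_qa = cross_entropy_mean(scores, targets)
    loss_aq = cross_entropy_mean(scores_t, targets)
    return (loss_qa + loss_aq) / 2


queries = [[1.0, 0.0], [0.0, 1.0]]
aligned = bidirectional_loss(queries, [[5.0, 0.0], [0.0, 5.0]])
shuffled = bidirectional_loss(queries, [[0.0, 5.0], [5.0, 0.0]])
```

With these toy vectors, `aligned` is much smaller than `shuffled`, because each query scores highest on its own document only in the aligned case.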
@@ -32,6 +32,7 @@ from .configuration_auto import (
|
|||||||
LongformerConfig,
|
LongformerConfig,
|
||||||
OpenAIGPTConfig,
|
OpenAIGPTConfig,
|
||||||
ReformerConfig,
|
ReformerConfig,
|
||||||
|
RetriBertConfig,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
T5Config,
|
T5Config,
|
||||||
TransfoXLConfig,
|
TransfoXLConfig,
|
||||||
@@ -55,6 +56,7 @@ from .tokenization_longformer import LongformerTokenizer
|
|||||||
from .tokenization_marian import MarianTokenizer
|
from .tokenization_marian import MarianTokenizer
|
||||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||||
from .tokenization_reformer import ReformerTokenizer
|
from .tokenization_reformer import ReformerTokenizer
|
||||||
|
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
|
||||||
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||||
from .tokenization_t5 import T5Tokenizer
|
from .tokenization_t5 import T5Tokenizer
|
||||||
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast
|
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast
|
||||||
@@ -68,6 +70,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
TOKENIZER_MAPPING = OrderedDict(
|
TOKENIZER_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
|
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
|
||||||
(T5Config, (T5Tokenizer, None)),
|
(T5Config, (T5Tokenizer, None)),
|
||||||
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
||||||
(AlbertConfig, (AlbertTokenizer, None)),
|
(AlbertConfig, (AlbertTokenizer, None)),
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ _all_bart_models = [
|
|||||||
"facebook/bart-large-mnli",
|
"facebook/bart-large-mnli",
|
||||||
"facebook/bart-large-cnn",
|
"facebook/bart-large-cnn",
|
||||||
"facebook/bart-large-xsum",
|
"facebook/bart-large-xsum",
|
||||||
|
"yjernite/bart_eli5",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
76
src/transformers/tokenization_retribert.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tokenization classes for RetriBERT."""
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||||
|
|
||||||
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
|
"vocab_file": {
|
||||||
|
"yjernite/retribert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
|
"yjernite/retribert-base-uncased": 512,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PRETRAINED_INIT_CONFIGURATION = {
|
||||||
|
"yjernite/retribert-base-uncased": {"do_lower_case": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RetriBertTokenizer(BertTokenizer):
|
||||||
|
r"""
|
||||||
|
Constructs a RetriBertTokenizer.
|
||||||
|
|
||||||
|
:class:`~transformers.RetriBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
|
||||||
|
tokenization: punctuation splitting + wordpiece.
|
||||||
|
|
||||||
|
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
|
||||||
|
parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||||
|
model_input_names = ["attention_mask"]
|
||||||
|
|
||||||
|
|
||||||
|
class RetriBertTokenizerFast(BertTokenizerFast):
|
||||||
|
r"""
|
||||||
|
Constructs a "Fast" RetriBertTokenizer (backed by HuggingFace's `tokenizers` library).
|
||||||
|
|
||||||
|
:class:`~transformers.RetriBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
|
||||||
|
tokenization: punctuation splitting + wordpiece.
|
||||||
|
|
||||||
|
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
|
||||||
|
parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||||
|
model_input_names = ["attention_mask"]
|
||||||
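The docstrings above describe the tokenizer as punctuation splitting followed by WordPiece. The toy sketch below applies greedy longest-match-first WordPiece to a single lowercased word. It assumes a tiny made-up vocabulary and the `##` continuation prefix used by BERT vocabularies. It is an illustration only, not the library's implementation.

```python
def wordpiece(word, vocab, unk_token="[UNK]"):
    """Split `word` into the longest vocabulary pieces, scanning left to right."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is found in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # mark pieces that continue a word
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece fits: the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary, for illustration only.
TOY_VOCAB = {"retri", "##eval", "##ever", "un", "##able"}
```

For example, `wordpiece("retrieval", TOY_VOCAB)` yields `["retri", "##eval"]`, while a word with no matching pieces collapses to `["[UNK]"]`.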