diff --git a/docs/source/index.rst b/docs/source/index.rst index 8529712f32..12c670ed06 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -111,3 +111,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train model_doc/reformer model_doc/marian model_doc/longformer + model_doc/retribert diff --git a/docs/source/model_doc/retribert.rst b/docs/source/model_doc/retribert.rst new file mode 100644 index 0000000000..c26f61dc08 --- /dev/null +++ b/docs/source/model_doc/retribert.rst @@ -0,0 +1,39 @@ +RetriBERT +---------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~ + +The RetriBERT model was proposed in the blog post +`Explain Anything Like I'm Five: A Model for Open Domain Long Form Question Answering `__, +RetriBERT is a small model that uses either a single or pair of Bert encoders with lower-dimension projection for dense semantic indexing of text. + +Code to train and use the model can be found `here `_. + + +RetriBertConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RetriBertConfig + :members: + + +RetriBertTokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RetriBertTokenizer + :members: + + +RetriBertTokenizerFast +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RetriBertTokenizerFast + :members: + + +RetriBertModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RetriBertModel + :members: diff --git a/examples/longform-qa/README.md b/examples/longform-qa/README.md new file mode 100644 index 0000000000..5c85a90ab9 --- /dev/null +++ b/examples/longform-qa/README.md @@ -0,0 +1,5 @@ +# Long Form Question Answering + +This folder contains the code for the Long Form Question answering [demo](http://35.226.96.115:8080/) as well as methods to train and use a fully end-to-end Long Form Question Answering system using the [🤗transformers](https://github.com/huggingface/transformers) and [🤗nlp](https://github.com/huggingface/nlp) libraries. + +You can use these mothods to train your own system by following along the associate [notebook](https://github.com/huggingface/notebooks/blob/master/longform-qa/Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb) or [blog post](https://yjernite.github.io/lfqa.html). diff --git a/examples/longform-qa/eli5_app.py b/examples/longform-qa/eli5_app.py new file mode 100644 index 0000000000..e79f1d6ed1 --- /dev/null +++ b/examples/longform-qa/eli5_app.py @@ -0,0 +1,332 @@ +import numpy as np +import torch + +import faiss +import nlp +import streamlit as st +import transformers +from elasticsearch import Elasticsearch +from eli5_utils import ( + embed_questions_for_retrieval, + make_qa_s2s_model, + qa_s2s_generate, + query_es_index, + query_qa_dense_index, +) +from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer + + +MODEL_TYPE = "bart" +LOAD_DENSE_INDEX = True + + +@st.cache(allow_output_mutation=True) +def load_models(): + if LOAD_DENSE_INDEX: + qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased") + qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to("cuda:0") + _ = qar_model.eval() + else: + qar_tokenizer, qar_model = (None, None) + if MODEL_TYPE == "bart": + s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5") + s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to("cuda:0") + save_dict = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth") + s2s_model.load_state_dict(save_dict["model"]) + _ = s2s_model.eval() + else: + s2s_tokenizer, s2s_model = make_qa_s2s_model( + model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device="cuda:0" + ) + return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model) + + +@st.cache(allow_output_mutation=True) +def load_indexes(): + if LOAD_DENSE_INDEX: + faiss_res = faiss.StandardGpuResources() + wiki40b_passages = nlp.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"] + wiki40b_passage_reps = np.memmap( + "wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat", + dtype="float32", + mode="r", + shape=(wiki40b_passages.num_rows, 128), + ) + wiki40b_index_flat = faiss.IndexFlatIP(128) + wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat) + wiki40b_gpu_index_flat.add(wiki40b_passage_reps) # TODO fix for larger GPU + else: + wiki40b_passages, wiki40b_gpu_index_flat = (None, None) + es_client = Elasticsearch([{"host": "localhost", "port": "9200"}]) + return (wiki40b_passages, wiki40b_gpu_index_flat, es_client) + + +@st.cache(allow_output_mutation=True) +def load_train_data(): + eli5 = nlp.load_dataset("eli5", name="LFQA_reddit") + eli5_train = eli5["train_eli5"] + eli5_train_q_reps = np.memmap( + "eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128) + ) + eli5_train_q_index = faiss.IndexFlatIP(128) + eli5_train_q_index.add(eli5_train_q_reps) + return (eli5_train, eli5_train_q_index) + + +passages, gpu_dense_index, es_client = load_indexes() +qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models() +eli5_train, eli5_train_q_index = load_train_data() + + +def find_nearest_training(question, n_results=10): + q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model) + D, I = eli5_train_q_index.search(q_rep, n_results) + nn_examples = [eli5_train[int(i)] for i in I[0]] + return nn_examples + + +def make_support(question, source="wiki40b", method="dense", n_results=10): + if source == "none": + support_doc, hit_lst = ("

".join(["" for _ in range(11)]).strip(), []) + else: + if method == "dense": + support_doc, hit_lst = query_qa_dense_index( + question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results + ) + else: + support_doc, hit_lst = query_es_index( + question, es_client, index_name="english_wiki40b_snippets_100w", n_results=n_results, + ) + support_list = [ + (res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst + ] + question_doc = "question: {} context: {}".format(question, support_doc) + return question_doc, support_list + + +@st.cache(hash_funcs={torch.Tensor: (lambda _: None), transformers.tokenization_bart.BartTokenizer: (lambda _: None)}) +def answer_question( + question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8 +): + with torch.no_grad(): + answer = qa_s2s_generate( + question_doc, + s2s_model, + s2s_tokenizer, + num_answers=1, + num_beams=n_beams, + min_len=min_len, + max_len=max_len, + do_sample=sampling, + temp=temp, + top_p=top_p, + top_k=None, + max_input_length=1024, + device="cuda:0", + )[0] + return (answer, support_list) + + +st.title("Long Form Question Answering with ELI5") + +# Start sidebar +header_html = "" +header_full = """ + + + + + + + %s + + + +""" % ( + header_html, +) +st.sidebar.markdown( + header_full, unsafe_allow_html=True, +) + +# Long Form QA with ELI5 and Wikipedia +description = """ +This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html). +First, a document retriever fetches a set of relevant Wikipedia passages given the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset, +a pre-processed fixed snapshot of Wikipedia. +""" +st.sidebar.markdown(description, unsafe_allow_html=True) + +action_list = [ + "Answer the question", + "View the retrieved document only", + "View the most similar ELI5 question and answer", + "Show me everything, please!", +] +demo_options = st.sidebar.checkbox("Demo options") +if demo_options: + action_st = st.sidebar.selectbox("", action_list, index=3,) + action = action_list.index(action_st) + show_type = st.sidebar.selectbox("", ["Show full text of passages", "Show passage section titles"], index=0,) + show_passages = show_type == "Show full text of passages" +else: + action = 3 + show_passages = True + +retrieval_options = st.sidebar.checkbox("Retrieval options") +if retrieval_options: + retriever_info = """ + ### Information retriever options + + The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between a question and passage embedding + trained using the [ELI5](https://arxiv.org/abs/1907.09190) questions-answer pairs. + The answer is then generated by sequence to sequence model which takes the question and retrieved document as input. + """ + st.sidebar.markdown(retriever_info) + wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"]) + index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"]) +else: + wiki_source = "wiki40b" + index_type = "dense" + +sampled = "beam" +n_beams = 2 +min_len = 64 +max_len = 256 +top_p = None +temp = None +generate_options = st.sidebar.checkbox("Generation options") +if generate_options: + generate_info = """ + ### Answer generation options + + The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large) + weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can use the model for greedy decoding with + **beam** search, or **sample** from the decoder's output probabilities. + """ + st.sidebar.markdown(generate_info) + sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"]) + min_len = st.sidebar.slider( + "Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None + ) + max_len = st.sidebar.slider( + "Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None + ) + if sampled == "beam": + n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None) + else: + top_p = st.sidebar.slider( + "Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None + ) + temp = st.sidebar.slider( + "Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None + ) + n_beams = None + +# start main text +questions_list = [ + "", + "How do people make chocolate?", + "Why do we get a fever when we are sick?", + "How can different animals perceive different colors?", + "What is natural language processing?", + "What's the best way to treat a sunburn?", + "What exactly are vitamins ?", + "How does nuclear energy provide electricity?", + "What's the difference between viruses and bacteria?", + "Why are flutes classified as woodwinds when most of them are made out of metal ?", + "Why do people like drinking coffee even though it tastes so bad?", + "What happens when wine ages? How does it make the wine taste better?", + "If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?", + "How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?", + "How does New Zealand have so many large bird predators?", +] +question_s = st.selectbox( + "What would you like to ask? ---- select to enter a new query", questions_list, index=1, +) +if question_s == "": + question = st.text_input("Enter your question here:", "") +else: + question = question_s + +if st.button("Show me!"): + if action in [0, 1, 3]: + if index_type == "mixed": + _, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10) + _, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10) + support_list = [] + for res_d, res_s in zip(support_list_dense, support_list_sparse): + if tuple(res_d) not in support_list: + support_list += [tuple(res_d)] + if tuple(res_s) not in support_list: + support_list += [tuple(res_s)] + support_list = support_list[:10] + question_doc = "

" + "

".join([res[-1] for res in support_list]) + else: + question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10) + if action in [0, 3]: + answer, support_list = answer_question( + question_doc, + s2s_model, + s2s_tokenizer, + min_len=min_len, + max_len=int(max_len), + sampling=(sampled == "sampled"), + n_beams=n_beams, + top_p=top_p, + temp=temp, + ) + st.markdown("### The model generated answer is:") + st.write(answer) + if action in [0, 1, 3] and wiki_source != "none": + st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:") + for i, res in enumerate(support_list): + wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_")) + sec_titles = res[1].strip() + if sec_titles == "": + sections = "[{}]({})".format(res[0], wiki_url) + else: + sec_list = sec_titles.split(" & ") + sections = " & ".join( + ["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list] + ) + st.markdown( + "{0:02d} - **Article**: {1:<18}
_Section_: {2}".format(i + 1, res[0], sections), + unsafe_allow_html=True, + ) + if show_passages: + st.write( + '> ' + res[-1] + "", unsafe_allow_html=True + ) + if action in [2, 3]: + nn_train_list = find_nearest_training(question) + train_exple = nn_train_list[0] + st.markdown( + "--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"]) + ) + answers_st = [ + "{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""])) + for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"])) + if i == 0 or sc > 2 + ] + st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st))) + + +disclaimer = """ +--- + +**Disclaimer** + +*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system. +Evaluating biases of such a model and ensuring factual generations are still very much open research problems. +Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.* +""" +st.sidebar.markdown(disclaimer, unsafe_allow_html=True) diff --git a/examples/longform-qa/eli5_utils.py b/examples/longform-qa/eli5_utils.py new file mode 100644 index 0000000000..2a431fd453 --- /dev/null +++ b/examples/longform-qa/eli5_utils.py @@ -0,0 +1,653 @@ +import functools +import math +import os # noqa: F401 +from random import choice, randint +from time import time + +import numpy as np +import torch +import torch.utils.checkpoint as checkpoint +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from tqdm import tqdm + +import faiss # noqa: F401 +import nlp # noqa: F401 +import pandas as pd +from elasticsearch import Elasticsearch # noqa: F401 +from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401 +from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup + + +pd.set_option("display.max_colwidth", None) + + +############### +# Sparse index +############### +def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_kilt_snippets_100w"): + index_config = { + "settings": { + "number_of_shards": 1, + "analysis": {"analyzer": {"stop_standard": {"type": "standard", " stopwords": "_english_"}}}, + }, + "mappings": { + "properties": { + "article_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"}, + "section_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"}, + "passage_text": {"type": "text", "analyzer": "standard", "similarity": "BM25"}, + } + }, + } + es_client.indices.create(index=index_name, body=index_config) + number_of_docs = passages_dset.num_rows + progress = tqdm(unit="docs", total=number_of_docs) + successes = 0 + + def passage_generator(): + for passage in passages_dset: + yield passage + + # create the ES index + for ok, action in streaming_bulk(client=es_client, index=index_name, actions=passage_generator(),): + progress.update(1) + successes += ok + print("Indexed %d documents" % (successes,)) + + +def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_100w", n_results=10, min_length=20): + q = question.lower() + banned = ["how", "why", "what", "where", "which", "do", "does", "is", "?", "eli5", "eli5:"] + q = " ".join([w for w in q.split() if w not in banned]) + response = es_client.search( + index=index_name, + body={ + "query": { + "multi_match": { + "query": q, + "fields": ["article_title", "section_title", "passage_text^2"], + "type": "cross_fields", + } + }, + "size": 2 * n_results, + }, + ) + hits = response["hits"]["hits"] + support_doc = "

" + "

".join([hit["_source"]["passage_text"] for hit in hits]) + res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits] + for r, hit in zip(res_list, hits): + r["passage_id"] = hit["_id"] + r["score"] = hit["_score"] + r["passage_text"] = hit["_source"]["passage_text"] + res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results] + return support_doc, res_list + + +############### +# ELI5 retriever training +############### +class ELI5DatasetQARetriver(Dataset): + def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None): + self.data = examples_array + self.answer_thres = extra_answer_threshold + self.min_length = min_answer_length + self.training = training + self.n_samples = self.data.num_rows if n_samples is None else n_samples + + def __len__(self): + return self.n_samples + + def make_example(self, idx): + example = self.data[idx] + question = example["title"] + if self.training: + answers = [a for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"]))] + answer_tab = choice(answers).split(" ") + start_idx = randint(0, max(0, len(answer_tab) - self.min_length)) + answer_span = " ".join(answer_tab[start_idx:]) + else: + answer_span = example["answers"]["text"][0] + return (question, answer_span) + + def __getitem__(self, idx): + return self.make_example(idx % self.data.num_rows) + + +class RetrievalQAEmbedder(torch.nn.Module): + def __init__(self, sent_encoder, dim): + super(RetrievalQAEmbedder, self).__init__() + self.sent_encoder = sent_encoder + self.output_dim = 128 + self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False) + self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False) + self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean") + + def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1): + # reproduces BERT forward pass with checkpointing + if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size: + return self.sent_encoder(input_ids, attention_mask=attention_mask)[1] + else: + # prepare implicit variables + device = input_ids.device + input_shape = input_ids.size() + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + head_mask = [None] * self.sent_encoder.config.num_hidden_layers + extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask( + attention_mask, input_shape, device + ) + + # define function for checkpointing + def partial_encode(*inputs): + encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,) + sequence_output = encoder_outputs[0] + pooled_output = self.sent_encoder.pooler(sequence_output) + return pooled_output + + # run embedding layer on everything at once + embedding_output = self.sent_encoder.embeddings( + input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None + ) + # run encoding and pooling on one mini-batch at a time + pooled_output_list = [] + for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)): + b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] + b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] + pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask) + pooled_output_list.append(pooled_output) + return torch.cat(pooled_output_list, dim=0) + + def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1): + q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size) + return self.project_q(q_reps) + + def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1): + a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size) + return self.project_a(a_reps) + + def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1): + device = q_ids.device + q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size) + a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size) + compare_scores = torch.mm(q_reps, a_reps.t()) + loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device)) + loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device)) + loss = (loss_qa + loss_aq) / 2 + return loss + + +def make_qa_retriever_model(model_name="google/bert_uncased_L-8_H-512_A-8", from_file=None, device="cuda:0"): + tokenizer = AutoTokenizer.from_pretrained(model_name) + bert_model = AutoModel.from_pretrained(model_name).to(device) + # run bert_model on a dummy batch to get output dimension + d_ids = torch.LongTensor( + [[bert_model.config.bos_token_id if bert_model.config.bos_token_id is not None else 1]] + ).to(device) + d_mask = torch.LongTensor([[1]]).to(device) + sent_dim = bert_model(d_ids, attention_mask=d_mask)[1].shape[-1] + qa_embedder = RetrievalQAEmbedder(bert_model, sent_dim).to(device) + if from_file is not None: + param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states + qa_embedder.load_state_dict(param_dict["model"]) + return tokenizer, qa_embedder + + +def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"): + q_ls = [q for q, a in qa_list] + a_ls = [a for q, a in qa_list] + q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True) + q_ids, q_mask = ( + torch.LongTensor(q_toks["input_ids"]).to(device), + torch.LongTensor(q_toks["attention_mask"]).to(device), + ) + a_toks = tokenizer.batch_encode_plus(a_ls, max_length=max_len, pad_to_max_length=True) + a_ids, a_mask = ( + torch.LongTensor(a_toks["input_ids"]).to(device), + torch.LongTensor(a_toks["attention_mask"]).to(device), + ) + return (q_ids, q_mask, a_ids, a_mask) + + +def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0): + model.train() + # make iterator + train_sampler = RandomSampler(dataset) + model_collate_fn = functools.partial( + make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0" + ) + data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn) + epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True) + # accumulate loss since last print + loc_steps = 0 + loc_loss = 0.0 + st_time = time() + for step, batch in enumerate(epoch_iterator): + q_ids, q_mask, a_ids, a_mask = batch + pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size) + loss = pre_loss.sum() + # optimizer + loss.backward() + optimizer.step() + scheduler.step() + model.zero_grad() + # some printing within the epoch + loc_loss += loss.item() + loc_steps += 1 + if step % args.print_freq == 0 or step == 1: + print( + "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format( + e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time, + ) + ) + loc_loss = 0 + loc_steps = 0 + + +def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=0): + model.train() + model_collate_fn = functools.partial( + make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0" + ) + # make iterator + train_samplers = [RandomSampler(dataset) for dataset in dataset_list] + data_loaders = [ + DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn) + for dataset, train_sampler in zip(dataset_list, train_samplers) + ] + iterators = [iter(dloader) for dloader in data_loaders] + joint_iter = zip(*iterators) + # accumulate loss since last print + loc_steps = 0 + loc_loss = 0.0 + st_time = time() + for step, (batches,) in enumerate(zip(joint_iter)): + for batch in batches: + q_ids, q_mask, a_ids, a_mask = batch + loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size) + # optimizer + loss.backward() + optimizer.step() + scheduler.step() + model.zero_grad() + # some printing within the epoch + loc_loss += loss.item() + loc_steps += 1 + if step % args.print_freq == 0: + print( + "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format( + e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time, + ) + ) + loc_loss = 0 + loc_steps = 0 + + +def evaluate_qa_retriever(model, dataset, tokenizer, args): + model.eval() + # make iterator + eval_sampler = SequentialSampler(dataset) + model_collate_fn = functools.partial( + make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0" + ) + data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn) + epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True) + tot_loss = 0.0 + with torch.no_grad(): + for step, batch in enumerate(epoch_iterator): + q_ids, q_mask, a_ids, a_mask = batch + loss = model(q_ids, q_mask, a_ids, a_mask) + tot_loss += loss.item() + return tot_loss / (step + 1) + + +def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args): + qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8) + qar_scheduler = get_linear_schedule_with_warmup( + qar_optimizer, + num_warmup_steps=100, + num_training_steps=(qar_args.num_epochs + 1) * math.ceil(len(qar_train_dset) / qar_args.batch_size), + ) + for e in range(qar_args.num_epochs): + train_qa_retriever_epoch(qar_model, qar_train_dset, qar_tokenizer, qar_optimizer, qar_scheduler, qar_args, e) + m_save_dict = { + "model": qar_model.state_dict(), + "optimizer": qar_optimizer.state_dict(), + "scheduler": qar_scheduler.state_dict(), + } + print("Saving model {}".format(qar_args.model_save_name)) + torch.save(m_save_dict, "{}_{}.pth".format(qar_args.model_save_name, e)) + eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args) + print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss)) + + +############### +# ELI5 seq2seq model training +############### +class ELI5DatasetS2S(Dataset): + def __init__( + self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True + ): + self.training = training + self.data = examples_array + self.make_doc_function = make_doc_fun + self.document_cache = {} if document_cache is None else document_cache + assert not (make_doc_fun is None and document_cache is None) + # make index of specific question-answer pairs from multi-answers + if self.training: + self.qa_id_list = [ + (i, j) + for i, qa in enumerate(self.data) + for j, (a, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"])) + if j == 0 or sc >= extra_answer_threshold + ] + else: + self.qa_id_list = [(i, 0) for i in range(self.data.num_rows)] + + def __len__(self): + return len(self.qa_id_list) + + def make_example(self, idx): + i, j = self.qa_id_list[idx] + example = self.data[i] + question = example["title"] + " " + example["selftext"] + answer = example["answers"]["text"][j] + q_id = example["q_id"] + if self.make_doc_function is not None: + self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["title"])) + document = self.document_cache[q_id] + in_st = "question: {} context: {}".format( + question.lower().replace(" --t--", "").strip(), document.lower().strip(), + ) + out_st = answer + return (in_st, out_st) + + def __getitem__(self, idx): + return self.make_example(idx) + + +def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"): + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) + if from_file is not None: + param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states + model.load_state_dict(param_dict["model"]) + return tokenizer, model + + +def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"): + q_ls = [q for q, a in qa_list] + a_ls = [a for q, a in qa_list] + q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True) + q_ids, q_mask = ( + torch.LongTensor(q_toks["input_ids"]).to(device), + torch.LongTensor(q_toks["attention_mask"]).to(device), + ) + a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True) + a_ids, a_mask = ( + torch.LongTensor(a_toks["input_ids"]).to(device), + torch.LongTensor(a_toks["attention_mask"]).to(device), + ) + lm_labels = a_ids[:, 1:].contiguous().clone() + lm_labels[a_mask[:, 1:].contiguous() == 0] = -100 + model_inputs = { + "input_ids": q_ids, + "attention_mask": q_mask, + "decoder_input_ids": a_ids[:, :-1].contiguous(), + "lm_labels": lm_labels, + } + return model_inputs + + +def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False): + model.train() + # make iterator + if curriculum: + train_sampler = SequentialSampler(dataset) + else: + train_sampler = RandomSampler(dataset) + model_collate_fn = functools.partial( + make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0" + ) + data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn) + epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True) + # accumulate loss since last print + loc_steps = 0 + loc_loss = 0.0 + st_time = time() + for step, batch_inputs in enumerate(epoch_iterator): + pre_loss = model(**batch_inputs)[0] + loss = pre_loss.sum() / pre_loss.shape[0] + loss.backward() + # optimizer + if step % args.backward_freq == 0: + optimizer.step() + scheduler.step() + model.zero_grad() + # some printing within the epoch + loc_loss += loss.item() + loc_steps += 1 + if step % args.print_freq == 0 or step == 1: + print( + "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format( + e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time, + ) + ) + loc_loss = 0 + loc_steps = 0 + + +def eval_qa_s2s_epoch(model, dataset, tokenizer, args): + model.eval() + # make iterator + train_sampler = SequentialSampler(dataset) + model_collate_fn = functools.partial( + make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0" + ) + data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn) + epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True) + # accumulate loss since last print + loc_steps = 0 + loc_loss = 0.0 + st_time = time() + with torch.no_grad(): + for step, batch_inputs in enumerate(epoch_iterator): + pre_loss = model(**batch_inputs)[0] + loss = pre_loss.sum() / pre_loss.shape[0] + loc_loss += loss.item() + loc_steps += 1 + if step % args.print_freq == 0: + print( + "{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format( + step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time, + ) + ) + print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,)) + + +def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args): + s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8) + s2s_scheduler = get_linear_schedule_with_warmup( + s2s_optimizer, + num_warmup_steps=400, + num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size), + ) + for e in range(s2s_args.num_epochs): + train_qa_s2s_epoch( + qa_s2s_model, + s2s_train_dset, + qa_s2s_tokenizer, + s2s_optimizer, + s2s_scheduler, + s2s_args, + e, + curriculum=(e == 0), + ) + m_save_dict = { + "model": qa_s2s_model.state_dict(), + "optimizer": s2s_optimizer.state_dict(), + "scheduler": s2s_scheduler.state_dict(), + } + print("Saving model {}".format(s2s_args.model_save_name)) + eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args) + torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e)) + + +# generate answer from input "question: ... context:

..." +def qa_s2s_generate( + question_doc, + qa_s2s_model, + qa_s2s_tokenizer, + num_answers=1, + num_beams=None, + min_len=64, + max_len=256, + do_sample=False, + temp=1.0, + top_p=None, + top_k=None, + max_input_length=512, + device="cuda:0", +): + model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,) + n_beams = num_answers if num_beams is None else max(num_beams, num_answers) + generated_ids = qa_s2s_model.generate( + input_ids=model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], + min_length=min_len, + max_length=max_len, + do_sample=do_sample, + early_stopping=True, + num_beams=1 if do_sample else n_beams, + temperature=temp, + top_k=top_k, + top_p=top_p, + eos_token_id=qa_s2s_tokenizer.eos_token_id, + no_repeat_ngram_size=3, + num_return_sequences=num_answers, + decoder_start_token_id=qa_s2s_tokenizer.bos_token_id, + ) + return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids] + + +############### +# ELI5-trained retrieval model usage +############### +def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"): + a_toks = tokenizer.batch_encode_plus(passages, max_length=max_length, pad_to_max_length=True) + a_ids, a_mask = ( + torch.LongTensor(a_toks["input_ids"]).to(device), + torch.LongTensor(a_toks["attention_mask"]).to(device), + ) + with torch.no_grad(): + a_reps = qa_embedder.embed_answers(a_ids, a_mask).cpu().type(torch.float) + return a_reps.numpy() + + +def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"): + q_toks = tokenizer.batch_encode_plus(q_ls, max_length=128, pad_to_max_length=True) + q_ids, q_mask = ( + torch.LongTensor(q_toks["input_ids"]).to(device), + torch.LongTensor(q_toks["attention_mask"]).to(device), + ) + with torch.no_grad(): + q_reps = qa_embedder.embed_questions(q_ids, q_mask).cpu().type(torch.float) + return q_reps.numpy() + + +def make_qa_dense_index( + qa_embedder, + tokenizer, + passages_dset, + batch_size=512, + max_length=128, + index_name="kilt_passages_reps.dat", + dtype="float32", + device="cuda:0", +): + st_time = time() + fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128)) + n_batches = math.ceil(passages_dset.num_rows / batch_size) + for i in range(n_batches): + passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]] + reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device) + fp[i * batch_size : (i + 1) * batch_size] = reps + if i % 50 == 0: + print(i, time() - st_time) + + +def evaluate_retriever(qa_list, retriever_func, scoring_func, n_ret=10, verbose=False): + total_retriever_time = 0.0 + total_retriever_score = 0.0 + st_time = time() + for i, (question, answer) in enumerate(qa_list): + r_time = time() + retrieved_passages = retriever_func(question, n_ret) + total_retriever_time += time() - r_time + total_retriever_score += scoring_func(retrieved_passages, answer) + if verbose and ((i + 1) % 500 == 0 or i <= 1): + print( + "{:03d}: S-{:.4f} T-{:.4f} | {:.2f}".format( + i + 1, total_retriever_score / (i + 1), total_retriever_time / (i + 1), time() - st_time + ) + ) + return {"idf_recall": total_retriever_score / (i + 1), "retrieval_time": total_retriever_time / (i + 1)} + + +# build a support document for the question out of Wikipedia snippets +def query_qa_dense_index( + question, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20, device="cuda:0" +): + q_rep = embed_questions_for_retrieval([question], tokenizer, qa_embedder, device=device) + D, I = wiki_index.search(q_rep, 2 * n_results) + res_passages = [wiki_passages[int(i)] for i in I[0]] + support_doc = "

" + "

".join([p["passage_text"] for p in res_passages]) + res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages] + res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results] + for r, sc in zip(res_list, D[0]): + r["score"] = float(sc) + return support_doc, res_list + + +def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10): + q_rep = embed_questions_for_retrieval(questions, tokenizer, qa_embedder) + D, I = wiki_index.search(q_rep, n_results) + res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I] + support_doc_lst = [ + "

" + "

".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst + ] + all_res_lists = [] + for (res_passages, dl) in zip(res_passages_lst, D): + res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages] + for r, sc in zip(res_list, dl): + r["score"] = float(sc) + all_res_lists += [res_list[:]] + return support_doc_lst, all_res_lists + + +# find nearest neighbors of an answer or declarative text in Wikipedia snippets +def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20): + a_rep = embed_passages_for_retrieval([passage], tokenizer, qa_embedder) + D, I = wiki_index.search(a_rep, 2 * n_results) + res_passages = [wiki_passages[int(i)] for i in I[0]] + support_doc = "

" + "

".join([p["passage_text"] for p in res_passages]) + res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages] + res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results] + for r, sc, i in zip(res_list, D[0], I[0]): + r["passage_id"] = int(i) + r["score"] = float(sc) + return support_doc, res_list + + +def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10): + a_reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder) + D, I = wiki_index.search(a_reps, n_results) + res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I] + support_doc_lst = [ + "

" + "

".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst + ] + all_res_lists = [] + for (res_passages, dl, il) in zip(res_passages_lst, D, I): + res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages] + for r, sc, i in zip(res_list, dl, il): + r["passage_id"] = int(i) + r["score"] = float(sc) + all_res_lists += [res_list[:]] + return support_doc_lst, all_res_lists diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3c4d568b6a..f0181a0860 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -36,6 +36,7 @@ from .configuration_marian import MarianConfig from .configuration_mmbt import MMBTConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig +from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig @@ -130,6 +131,7 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_reformer import ReformerTokenizer +from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from .tokenization_t5 import T5Tokenizer from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast @@ -356,6 +358,12 @@ if is_torch_available(): LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, ) + from .modeling_retribert import ( + RetriBertPreTrainedModel, + RetriBertModel, + RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + ) + # Optimization from .optimization import ( AdamW, diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index e9ba1de597..c3806c12d5 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -32,6 +32,7 @@ from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, from .configuration_marian import MarianConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_reformer import ReformerConfig +from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig @@ -64,6 +65,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, + RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ] for key, value, in pretrained_map.items() ) @@ -71,6 +73,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( CONFIG_MAPPING = OrderedDict( [ + ("retribert", RetriBertConfig,), ("t5", T5Config,), ("distilbert", DistilBertConfig,), ("albert", AlbertConfig,), diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 3b9778ff23..00f3337b2e 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -28,6 +28,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { "facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json", "facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json", "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json", + "yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json", } diff --git a/src/transformers/configuration_retribert.py b/src/transformers/configuration_retribert.py new file mode 100644 index 0000000000..882c7c80cd --- /dev/null +++ b/src/transformers/configuration_retribert.py @@ -0,0 +1,112 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" RetriBERT model configuration """ + + +import logging + +from .configuration_utils import PretrainedConfig + + +logger = logging.getLogger(__name__) + +# TODO: uploadto AWS +RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "retribert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", +} + + +class RetriBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.RetriBertModel`. + It is used to instantiate a RetriBertModel model according to the specified arguments, defining the model + architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used + to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` + for more information. + + + Args: + vocab_size (:obj:`int`, optional, defaults to 30522): + Vocabulary size of the BERT model. Defines the different tokens that + can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. + hidden_size (:obj:`int`, optional, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, optional, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, optional, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, optional, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): + The non-linear activation function (function or string) in the encoder and pooler. + If string, "gelu", "relu", "swish" and "gelu_new" are supported. + hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, optional, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, optional, defaults to 2): + The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. + initializer_range (:obj:`float`, optional, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. + share_encoders (:obj:`bool`, optional, defaults to True): + Whether to use the same Bert-type encoder for the queries and document + projection_dim (:obj:`int`, optional, defaults to 128): + Final dimension of the query and document representation after projection + + """ + model_type = "retribert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=8, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + share_encoders=True, + projection_dim=128, + pad_token_id=0, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.share_encoders = share_encoders + self.projection_dim = projection_dim diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 6cfee27d44..da9281d192 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -34,6 +34,7 @@ from .configuration_auto import ( LongformerConfig, OpenAIGPTConfig, ReformerConfig, + RetriBertConfig, RobertaConfig, T5Config, TransfoXLConfig, @@ -111,6 +112,7 @@ from .modeling_longformer import ( from .modeling_marian import MarianMTModel from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel from .modeling_reformer import ReformerModel, ReformerModelWithLMHead +from .modeling_retribert import RetriBertModel from .modeling_roberta import ( RobertaForMaskedLM, RobertaForMultipleChoice, @@ -151,6 +153,7 @@ logger = logging.getLogger(__name__) MODEL_MAPPING = OrderedDict( [ + (RetriBertConfig, RetriBertModel), (T5Config, T5Model), (DistilBertConfig, DistilBertModel), (AlbertConfig, AlbertModel), @@ -174,6 +177,7 @@ MODEL_MAPPING = OrderedDict( MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( [ + (RetriBertConfig, RetriBertModel), (T5Config, T5ForConditionalGeneration), (DistilBertConfig, DistilBertForMaskedLM), (AlbertConfig, AlbertForPreTraining), diff --git a/src/transformers/modeling_retribert.py b/src/transformers/modeling_retribert.py new file mode 100644 index 0000000000..e0395ceb03 --- /dev/null +++ b/src/transformers/modeling_retribert.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +RetriBERT model +""" + + +import logging +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from .configuration_retribert import RetriBertConfig +from .file_utils import add_start_docstrings +from .modeling_bert import BertLayerNorm, BertModel +from .modeling_utils import PreTrainedModel + + +logger = logging.getLogger(__name__) + +RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "yjernite/retribert-base-uncased", + # See all RetriBert models at https://huggingface.co/models?filter=retribert +] + + +# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL # +class RetriBertPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = RetriBertConfig + load_tf_weights = None + base_model_prefix = "retribert" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +RETRIBERT_START_DOCSTRING = r""" + + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.RetriBertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + + +@add_start_docstrings( + """Bert Based model to embed queries or document for document retreival. """, RETRIBERT_START_DOCSTRING, +) +class RetriBertModel(RetriBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.projection_dim = config.projection_dim + + self.bert_query = BertModel(config) + self.bert_doc = None if config.share_encoders else BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.project_query = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + self.project_doc = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + + self.ce_loss = nn.CrossEntropyLoss(reduction="mean") + + self.init_weights() + + def embed_sentences_checkpointed( + self, input_ids, attention_mask, sent_encoder, checkpoint_batch_size=-1, + ): + # reproduces BERT forward pass with checkpointing + if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size: + return sent_encoder(input_ids, attention_mask=attention_mask)[1] + else: + # prepare implicit variables + device = input_ids.device + input_shape = input_ids.size() + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + head_mask = [None] * sent_encoder.config.num_hidden_layers + extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask( + attention_mask, input_shape, device + ) + + # define function for cehckpointing + def partial_encode(*inputs): + encoder_outputs = sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,) + sequence_output = encoder_outputs[0] + pooled_output = sent_encoder.pooler(sequence_output) + return pooled_output + + # run embedding layer on everything at once + embedding_output = sent_encoder.embeddings( + input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None + ) + # run encoding and pooling on one mini-batch at a time + pooled_output_list = [] + for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)): + b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] + b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] + pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask) + pooled_output_list.append(pooled_output) + return torch.cat(pooled_output_list, dim=0) + + def embed_questions( + self, input_ids, attention_mask=None, checkpoint_batch_size=-1, + ): + q_reps = self.embed_sentences_checkpointed(input_ids, attention_mask, self.bert_query, checkpoint_batch_size,) + return self.project_query(q_reps) + + def embed_answers( + self, input_ids, attention_mask=None, checkpoint_batch_size=-1, + ): + a_reps = self.embed_sentences_checkpointed( + input_ids, + attention_mask, + self.bert_query if self.bert_doc is None else self.bert_doc, + checkpoint_batch_size, + ) + return self.project_doc(a_reps) + + def forward( + self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1 + ): + r""" + Args: + input_ids_query (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary for the queries in a batch. + + Indices can be obtained using :class:`transformers.RetriBertTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.encode_plus` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask_query (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on queries padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + input_ids_doc (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary for the documents in a batch. + attention_mask_doc (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on documents padding token indices. + + checkpoint_batch_size (:obj:`int`, `optional`, defaults to `:obj:`-1`): + If greater than 0, uses gradient checkpointing to only compute sequence representation on checkpoint_batch_size examples at a time + on the GPU. All query representations are still compared to all document representations in the batch. + + Return: + :obj:`torch.FloatTensor` the bi-directional cross-entropy loss obtained while trying to match each query to its corresponding document + and each cocument to its corresponding query in the batch + """ + device = input_ids_query.device + q_reps = self.embed_questions(input_ids_query, attention_mask_query, checkpoint_batch_size) + a_reps = self.embed_answers(input_ids_doc, attention_mask_doc, checkpoint_batch_size) + compare_scores = torch.mm(q_reps, a_reps.t()) + loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device)) + loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device)) + loss = (loss_qa + loss_aq) / 2 + return loss diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index 0a00779890..cdd85982fe 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -32,6 +32,7 @@ from .configuration_auto import ( LongformerConfig, OpenAIGPTConfig, ReformerConfig, + RetriBertConfig, RobertaConfig, T5Config, TransfoXLConfig, @@ -55,6 +56,7 @@ from .tokenization_longformer import LongformerTokenizer from .tokenization_marian import MarianTokenizer from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_reformer import ReformerTokenizer +from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from .tokenization_t5 import T5Tokenizer from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast @@ -68,6 +70,7 @@ logger = logging.getLogger(__name__) TOKENIZER_MAPPING = OrderedDict( [ + (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)), (T5Config, (T5Tokenizer, None)), (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)), (AlbertConfig, (AlbertTokenizer, None)), diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 89cfaf1cff..cf4f01c8d0 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -33,6 +33,7 @@ _all_bart_models = [ "facebook/bart-large-mnli", "facebook/bart-large-cnn", "facebook/bart-large-xsum", + "yjernite/bart_eli5", ] diff --git a/src/transformers/tokenization_retribert.py b/src/transformers/tokenization_retribert.py new file mode 100644 index 0000000000..a544d0b8b9 --- /dev/null +++ b/src/transformers/tokenization_retribert.py @@ -0,0 +1,76 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for RetriBERT.""" + + +import logging + +from .tokenization_bert import BertTokenizer, BertTokenizerFast + + +logger = logging.getLogger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "yjernite/retribert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "yjernite/retribert-base-uncased": 512, +} + + +PRETRAINED_INIT_CONFIGURATION = { + "yjernite/retribert-base-uncased": {"do_lower_case": True}, +} + + +class RetriBertTokenizer(BertTokenizer): + r""" + Constructs a retribert. + + :class:`~transformers.retribert is identical to :class:`~transformers.BertTokenizer` and runs end-to-end + tokenization: punctuation splitting + wordpiece. + + Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning + parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + model_input_names = ["attention_mask"] + + +class RetriBertTokenizerFast(BertTokenizerFast): + r""" + Constructs a "Fast" RetriBertTokenizerFast (backed by HuggingFace's `tokenizers` library). + + :class:`~transformers.RetriBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end + tokenization: punctuation splitting + wordpiece. + + Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning + parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + model_input_names = ["attention_mask"]