Black 20 release
This commit is contained in:
@@ -95,7 +95,10 @@ def make_support(question, source="wiki40b", method="dense", n_results=10):
|
||||
)
|
||||
else:
|
||||
support_doc, hit_lst = query_es_index(
|
||||
question, es_client, index_name="english_wiki40b_snippets_100w", n_results=n_results,
|
||||
question,
|
||||
es_client,
|
||||
index_name="english_wiki40b_snippets_100w",
|
||||
n_results=n_results,
|
||||
)
|
||||
support_list = [
|
||||
(res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
|
||||
@@ -154,7 +157,8 @@ header_full = """
|
||||
header_html,
|
||||
)
|
||||
st.sidebar.markdown(
|
||||
header_full, unsafe_allow_html=True,
|
||||
header_full,
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
# Long Form QA with ELI5 and Wikipedia
|
||||
@@ -173,9 +177,17 @@ action_list = [
|
||||
]
|
||||
demo_options = st.sidebar.checkbox("Demo options")
|
||||
if demo_options:
|
||||
action_st = st.sidebar.selectbox("", action_list, index=3,)
|
||||
action_st = st.sidebar.selectbox(
|
||||
"",
|
||||
action_list,
|
||||
index=3,
|
||||
)
|
||||
action = action_list.index(action_st)
|
||||
show_type = st.sidebar.selectbox("", ["Show full text of passages", "Show passage section titles"], index=0,)
|
||||
show_type = st.sidebar.selectbox(
|
||||
"",
|
||||
["Show full text of passages", "Show passage section titles"],
|
||||
index=0,
|
||||
)
|
||||
show_passages = show_type == "Show full text of passages"
|
||||
else:
|
||||
action = 3
|
||||
@@ -250,7 +262,9 @@ questions_list = [
|
||||
"How does New Zealand have so many large bird predators?",
|
||||
]
|
||||
question_s = st.selectbox(
|
||||
"What would you like to ask? ---- select <MY QUESTION> to enter a new query", questions_list, index=1,
|
||||
"What would you like to ask? ---- select <MY QUESTION> to enter a new query",
|
||||
questions_list,
|
||||
index=1,
|
||||
)
|
||||
if question_s == "<MY QUESTION>":
|
||||
question = st.text_input("Enter your question here:", "")
|
||||
|
||||
@@ -48,7 +48,11 @@ def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_ki
|
||||
yield passage
|
||||
|
||||
# create the ES index
|
||||
for ok, action in streaming_bulk(client=es_client, index=index_name, actions=passage_generator(),):
|
||||
for ok, action in streaming_bulk(
|
||||
client=es_client,
|
||||
index=index_name,
|
||||
actions=passage_generator(),
|
||||
):
|
||||
progress.update(1)
|
||||
successes += ok
|
||||
print("Indexed %d documents" % (successes,))
|
||||
@@ -137,7 +141,11 @@ class RetrievalQAEmbedder(torch.nn.Module):
|
||||
|
||||
# define function for checkpointing
|
||||
def partial_encode(*inputs):
|
||||
encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
|
||||
encoder_outputs = self.sent_encoder.encoder(
|
||||
inputs[0],
|
||||
attention_mask=inputs[1],
|
||||
head_mask=head_mask,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.sent_encoder.pooler(sequence_output)
|
||||
return pooled_output
|
||||
@@ -234,7 +242,11 @@ def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, ar
|
||||
if step % args.print_freq == 0 or step == 1:
|
||||
print(
|
||||
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
|
||||
e,
|
||||
step,
|
||||
len(dataset) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
loc_loss = 0
|
||||
@@ -273,7 +285,11 @@ def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, sc
|
||||
if step % args.print_freq == 0:
|
||||
print(
|
||||
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,
|
||||
e,
|
||||
step,
|
||||
len(dataset_list[0]) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
loc_loss = 0
|
||||
@@ -354,7 +370,8 @@ class ELI5DatasetS2S(Dataset):
|
||||
self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["title"]))
|
||||
document = self.document_cache[q_id]
|
||||
in_st = "question: {} context: {}".format(
|
||||
question.lower().replace(" --t--", "").strip(), document.lower().strip(),
|
||||
question.lower().replace(" --t--", "").strip(),
|
||||
document.lower().strip(),
|
||||
)
|
||||
out_st = answer
|
||||
return (in_st, out_st)
|
||||
@@ -427,7 +444,11 @@ def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=
|
||||
if step % args.print_freq == 0 or step == 1:
|
||||
print(
|
||||
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
|
||||
e,
|
||||
step,
|
||||
len(dataset) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
loc_loss = 0
|
||||
@@ -456,10 +477,18 @@ def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
|
||||
if step % args.print_freq == 0:
|
||||
print(
|
||||
"{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
|
||||
step,
|
||||
len(dataset) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,))
|
||||
print(
|
||||
"Total \t L: {:.3f} \t -- {:.3f}".format(
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
|
||||
@@ -506,7 +535,12 @@ def qa_s2s_generate(
|
||||
max_input_length=512,
|
||||
device="cuda:0",
|
||||
):
|
||||
model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,)
|
||||
model_inputs = make_qa_s2s_batch(
|
||||
[(question_doc, "A")],
|
||||
qa_s2s_tokenizer,
|
||||
max_input_length,
|
||||
device=device,
|
||||
)
|
||||
n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
|
||||
generated_ids = qa_s2s_model.generate(
|
||||
input_ids=model_inputs["input_ids"],
|
||||
|
||||
Reference in New Issue
Block a user