Black 20 release
This commit is contained in:
@@ -48,7 +48,11 @@ def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_ki
|
||||
yield passage
|
||||
|
||||
# create the ES index
|
||||
for ok, action in streaming_bulk(client=es_client, index=index_name, actions=passage_generator(),):
|
||||
for ok, action in streaming_bulk(
|
||||
client=es_client,
|
||||
index=index_name,
|
||||
actions=passage_generator(),
|
||||
):
|
||||
progress.update(1)
|
||||
successes += ok
|
||||
print("Indexed %d documents" % (successes,))
|
||||
@@ -137,7 +141,11 @@ class RetrievalQAEmbedder(torch.nn.Module):
|
||||
|
||||
# define function for checkpointing
|
||||
def partial_encode(*inputs):
|
||||
encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
|
||||
encoder_outputs = self.sent_encoder.encoder(
|
||||
inputs[0],
|
||||
attention_mask=inputs[1],
|
||||
head_mask=head_mask,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.sent_encoder.pooler(sequence_output)
|
||||
return pooled_output
|
||||
@@ -234,7 +242,11 @@ def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, ar
|
||||
if step % args.print_freq == 0 or step == 1:
|
||||
print(
|
||||
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
|
||||
e,
|
||||
step,
|
||||
len(dataset) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
loc_loss = 0
|
||||
@@ -273,7 +285,11 @@ def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, sc
|
||||
if step % args.print_freq == 0:
|
||||
print(
|
||||
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,
|
||||
e,
|
||||
step,
|
||||
len(dataset_list[0]) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
loc_loss = 0
|
||||
@@ -354,7 +370,8 @@ class ELI5DatasetS2S(Dataset):
|
||||
self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["title"]))
|
||||
document = self.document_cache[q_id]
|
||||
in_st = "question: {} context: {}".format(
|
||||
question.lower().replace(" --t--", "").strip(), document.lower().strip(),
|
||||
question.lower().replace(" --t--", "").strip(),
|
||||
document.lower().strip(),
|
||||
)
|
||||
out_st = answer
|
||||
return (in_st, out_st)
|
||||
@@ -427,7 +444,11 @@ def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=
|
||||
if step % args.print_freq == 0 or step == 1:
|
||||
print(
|
||||
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
|
||||
e,
|
||||
step,
|
||||
len(dataset) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
loc_loss = 0
|
||||
@@ -456,10 +477,18 @@ def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
|
||||
if step % args.print_freq == 0:
|
||||
print(
|
||||
"{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
|
||||
step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
|
||||
step,
|
||||
len(dataset) // args.batch_size,
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,))
|
||||
print(
|
||||
"Total \t L: {:.3f} \t -- {:.3f}".format(
|
||||
loc_loss / loc_steps,
|
||||
time() - st_time,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
|
||||
@@ -506,7 +535,12 @@ def qa_s2s_generate(
|
||||
max_input_length=512,
|
||||
device="cuda:0",
|
||||
):
|
||||
model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,)
|
||||
model_inputs = make_qa_s2s_batch(
|
||||
[(question_doc, "A")],
|
||||
qa_s2s_tokenizer,
|
||||
max_input_length,
|
||||
device=device,
|
||||
)
|
||||
n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
|
||||
generated_ids = qa_s2s_model.generate(
|
||||
input_ids=model_inputs["input_ids"],
|
||||
|
||||
Reference in New Issue
Block a user