Black 20 release

This commit is contained in:
Lysandre
2020-08-26 17:20:22 +02:00
parent e78c110338
commit a75c64d80c
191 changed files with 4807 additions and 3503 deletions

View File

@@ -48,7 +48,11 @@ def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_ki
yield passage
# create the ES index
for ok, action in streaming_bulk(client=es_client, index=index_name, actions=passage_generator(),):
for ok, action in streaming_bulk(
client=es_client,
index=index_name,
actions=passage_generator(),
):
progress.update(1)
successes += ok
print("Indexed %d documents" % (successes,))
@@ -137,7 +141,11 @@ class RetrievalQAEmbedder(torch.nn.Module):
# define function for checkpointing
def partial_encode(*inputs):
encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
encoder_outputs = self.sent_encoder.encoder(
inputs[0],
attention_mask=inputs[1],
head_mask=head_mask,
)
sequence_output = encoder_outputs[0]
pooled_output = self.sent_encoder.pooler(sequence_output)
return pooled_output
@@ -234,7 +242,11 @@ def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, ar
if step % args.print_freq == 0 or step == 1:
print(
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
e,
step,
len(dataset) // args.batch_size,
loc_loss / loc_steps,
time() - st_time,
)
)
loc_loss = 0
@@ -273,7 +285,11 @@ def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, sc
if step % args.print_freq == 0:
print(
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,
e,
step,
len(dataset_list[0]) // args.batch_size,
loc_loss / loc_steps,
time() - st_time,
)
)
loc_loss = 0
@@ -354,7 +370,8 @@ class ELI5DatasetS2S(Dataset):
self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["title"]))
document = self.document_cache[q_id]
in_st = "question: {} context: {}".format(
question.lower().replace(" --t--", "").strip(), document.lower().strip(),
question.lower().replace(" --t--", "").strip(),
document.lower().strip(),
)
out_st = answer
return (in_st, out_st)
@@ -427,7 +444,11 @@ def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=
if step % args.print_freq == 0 or step == 1:
print(
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
e,
step,
len(dataset) // args.batch_size,
loc_loss / loc_steps,
time() - st_time,
)
)
loc_loss = 0
@@ -456,10 +477,18 @@ def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
if step % args.print_freq == 0:
print(
"{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
step,
len(dataset) // args.batch_size,
loc_loss / loc_steps,
time() - st_time,
)
)
print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,))
print(
"Total \t L: {:.3f} \t -- {:.3f}".format(
loc_loss / loc_steps,
time() - st_time,
)
)
def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
@@ -506,7 +535,12 @@ def qa_s2s_generate(
max_input_length=512,
device="cuda:0",
):
model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,)
model_inputs = make_qa_s2s_batch(
[(question_doc, "A")],
qa_s2s_tokenizer,
max_input_length,
device=device,
)
n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
generated_ids = qa_s2s_model.generate(
input_ids=model_inputs["input_ids"],