Apply ruff flake8-comprehensions (#21694)

This commit is contained in:
Aaron Gokaslan
2023-02-22 03:14:54 -05:00
committed by GitHub
parent df06fb1f0b
commit 5e8c8eb5ba
230 changed files with 971 additions and 955 deletions

View File

@@ -78,7 +78,7 @@ def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_1
)
hits = response["hits"]["hits"]
support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
res_list = [{k: hit["_source"][k] for k in hit["_source"] if k != "passage_text"} for hit in hits]
for r, hit in zip(res_list, hits):
r["passage_id"] = hit["_id"]
r["score"] = hit["_score"]
@@ -601,7 +601,7 @@ def make_qa_dense_index(
fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
n_batches = math.ceil(passages_dset.num_rows / batch_size)
for i in range(n_batches):
passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
passages = list(passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"])
reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
fp[i * batch_size : (i + 1) * batch_size] = reps
if i % 50 == 0:
@@ -634,7 +634,7 @@ def query_qa_dense_index(
D, I = wiki_index.search(q_rep, 2 * n_results)
res_passages = [wiki_passages[int(i)] for i in I[0]]
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
for r, sc in zip(res_list, D[0]):
r["score"] = float(sc)
@@ -650,7 +650,7 @@ def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages,
]
all_res_lists = []
for res_passages, dl in zip(res_passages_lst, D):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
for r, sc in zip(res_list, dl):
r["score"] = float(sc)
all_res_lists += [res_list[:]]
@@ -663,7 +663,7 @@ def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki
D, I = wiki_index.search(a_rep, 2 * n_results)
res_passages = [wiki_passages[int(i)] for i in I[0]]
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
for r, sc, i in zip(res_list, D[0], I[0]):
r["passage_id"] = int(i)
@@ -680,7 +680,7 @@ def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passage
]
all_res_lists = []
for res_passages, dl, il in zip(res_passages_lst, D, I):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
for r, sc, i in zip(res_list, dl, il):
r["passage_id"] = int(i)
r["score"] = float(sc)