Apply ruff flake8-comprehensions (#21694)
This commit is contained in:
@@ -78,7 +78,7 @@ def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_1
|
||||
)
|
||||
hits = response["hits"]["hits"]
|
||||
support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
|
||||
res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
|
||||
res_list = [{k: hit["_source"][k] for k in hit["_source"] if k != "passage_text"} for hit in hits]
|
||||
for r, hit in zip(res_list, hits):
|
||||
r["passage_id"] = hit["_id"]
|
||||
r["score"] = hit["_score"]
|
||||
@@ -601,7 +601,7 @@ def make_qa_dense_index(
|
||||
fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
|
||||
n_batches = math.ceil(passages_dset.num_rows / batch_size)
|
||||
for i in range(n_batches):
|
||||
passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
|
||||
passages = list(passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"])
|
||||
reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
|
||||
fp[i * batch_size : (i + 1) * batch_size] = reps
|
||||
if i % 50 == 0:
|
||||
@@ -634,7 +634,7 @@ def query_qa_dense_index(
|
||||
D, I = wiki_index.search(q_rep, 2 * n_results)
|
||||
res_passages = [wiki_passages[int(i)] for i in I[0]]
|
||||
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
|
||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||
res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
|
||||
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
|
||||
for r, sc in zip(res_list, D[0]):
|
||||
r["score"] = float(sc)
|
||||
@@ -650,7 +650,7 @@ def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages,
|
||||
]
|
||||
all_res_lists = []
|
||||
for res_passages, dl in zip(res_passages_lst, D):
|
||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||
res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
|
||||
for r, sc in zip(res_list, dl):
|
||||
r["score"] = float(sc)
|
||||
all_res_lists += [res_list[:]]
|
||||
@@ -663,7 +663,7 @@ def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki
|
||||
D, I = wiki_index.search(a_rep, 2 * n_results)
|
||||
res_passages = [wiki_passages[int(i)] for i in I[0]]
|
||||
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
|
||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||
res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
|
||||
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
|
||||
for r, sc, i in zip(res_list, D[0], I[0]):
|
||||
r["passage_id"] = int(i)
|
||||
@@ -680,7 +680,7 @@ def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passage
|
||||
]
|
||||
all_res_lists = []
|
||||
for res_passages, dl, il in zip(res_passages_lst, D, I):
|
||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||
res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
|
||||
for r, sc, i in zip(res_list, dl, il):
|
||||
r["passage_id"] = int(i)
|
||||
r["score"] = float(sc)
|
||||
|
||||
Reference in New Issue
Block a user