Apply ruff flake8-comprehensions (#21694)

This commit is contained in:
Aaron Gokaslan
2023-02-22 03:14:54 -05:00
committed by GitHub
parent df06fb1f0b
commit 5e8c8eb5ba
230 changed files with 971 additions and 955 deletions

View File

@@ -247,9 +247,12 @@ class Trainer:
lr = self.scheduler_fn(state_step - 1)
eval_loss = self.evaluate(state, val_dataset)
logging_dict = dict(
step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
)
logging_dict = {
"step": state_step.item(),
"eval_loss": eval_loss.item(),
"tr_loss": tr_loss,
"lr": lr.item(),
}
tqdm.write(str(logging_dict))
self.logger.log(logging_dict, commit=True)

View File

@@ -144,9 +144,9 @@ def main():
predictions = expand_to_aliases(example["output"])
# apply some preprocessing to both predictions and answers
answers = set(["".join(a.split()) for a in answers])
predictions = set(["".join(p.split()) for p in predictions])
predictions = set([s for s in predictions if s not in ["``", "''", "`", "'"]])
answers = {"".join(a.split()) for a in answers}
predictions = {"".join(p.split()) for p in predictions}
predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}
# if there is a common element, it's an exact match
example["match"] = len(list(answers & predictions)) > 0

View File

@@ -314,12 +314,12 @@ if __name__ == "__main__":
data = data["train" if PROCESS_TRAIN == "true" else "validation"]
fn_kwargs = dict(
tokenizer=tokenizer,
doc_stride=DOC_STRIDE,
max_length=MAX_LENGTH,
assertion=False,
)
fn_kwargs = {
"tokenizer": tokenizer,
"doc_stride": DOC_STRIDE,
"max_length": MAX_LENGTH,
"assertion": False,
}
data = data.map(prepare_inputs, fn_kwargs=fn_kwargs)
data = data.remove_columns(["annotations", "document", "id", "question"])
print(data)

View File

@@ -34,7 +34,7 @@ empty_dict = object()
def _match(qs, ks):
"""Return True if regexes in qs match any window of strings in tuple ks."""
# compile regexes and force complete match
qts = tuple(map(lambda x: re.compile(x + "$"), qs))
qts = tuple((re.compile(x + "$") for x in qs))
for i in range(len(ks) - len(qs) + 1):
matches = [x.match(y) for x, y in zip(qts, ks[i:])]
if matches and all(matches):