Apply ruff flake8-comprehensions (#21694)
This commit is contained in:
@@ -247,9 +247,12 @@ class Trainer:
|
||||
lr = self.scheduler_fn(state_step - 1)
|
||||
|
||||
eval_loss = self.evaluate(state, val_dataset)
|
||||
logging_dict = dict(
|
||||
step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
|
||||
)
|
||||
logging_dict = {
|
||||
"step": state_step.item(),
|
||||
"eval_loss": eval_loss.item(),
|
||||
"tr_loss": tr_loss,
|
||||
"lr": lr.item(),
|
||||
}
|
||||
tqdm.write(str(logging_dict))
|
||||
self.logger.log(logging_dict, commit=True)
|
||||
|
||||
|
||||
@@ -144,9 +144,9 @@ def main():
|
||||
predictions = expand_to_aliases(example["output"])
|
||||
|
||||
# some preprocessing to both prediction and answer
|
||||
answers = set(["".join(a.split()) for a in answers])
|
||||
predictions = set(["".join(p.split()) for p in predictions])
|
||||
predictions = set([s for s in predictions if s not in ["``", "''", "`", "'"]])
|
||||
answers = {"".join(a.split()) for a in answers}
|
||||
predictions = {"".join(p.split()) for p in predictions}
|
||||
predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}
|
||||
|
||||
# if there is a common element, it's a exact match
|
||||
example["match"] = len(list(answers & predictions)) > 0
|
||||
|
||||
@@ -314,12 +314,12 @@ if __name__ == "__main__":
|
||||
|
||||
data = data["train" if PROCESS_TRAIN == "true" else "validation"]
|
||||
|
||||
fn_kwargs = dict(
|
||||
tokenizer=tokenizer,
|
||||
doc_stride=DOC_STRIDE,
|
||||
max_length=MAX_LENGTH,
|
||||
assertion=False,
|
||||
)
|
||||
fn_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"doc_stride": DOC_STRIDE,
|
||||
"max_length": MAX_LENGTH,
|
||||
"assertion": False,
|
||||
}
|
||||
data = data.map(prepare_inputs, fn_kwargs=fn_kwargs)
|
||||
data = data.remove_columns(["annotations", "document", "id", "question"])
|
||||
print(data)
|
||||
|
||||
@@ -34,7 +34,7 @@ empty_dict = object()
|
||||
def _match(qs, ks):
|
||||
"""Return True if regexes in qs match any window of strings in tuple ks."""
|
||||
# compile regexes and force complete match
|
||||
qts = tuple(map(lambda x: re.compile(x + "$"), qs))
|
||||
qts = tuple((re.compile(x + "$") for x in qs))
|
||||
for i in range(len(ks) - len(qs) + 1):
|
||||
matches = [x.match(y) for x, y in zip(qts, ks[i:])]
|
||||
if matches and all(matches):
|
||||
|
||||
Reference in New Issue
Block a user