Apply ruff flake8-comprehensions (#21694)

This commit is contained in:
Aaron Gokaslan
2023-02-22 03:14:54 -05:00
committed by GitHub
parent df06fb1f0b
commit 5e8c8eb5ba
230 changed files with 971 additions and 955 deletions

View File

@@ -192,7 +192,7 @@ def main():
# Optionally, predict on dev set and write to output_dir
if args.do_predict:
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
model = model.load_from_checkpoint(checkpoints[-1])
return trainer.test(model)

View File

@@ -211,6 +211,6 @@ if __name__ == "__main__":
# pl use this default format to create a checkpoint:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
# /pytorch_lightning/callbacks/model_checkpoint.py#L322
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
model = model.load_from_checkpoint(checkpoints[-1])
trainer.test(model)

View File

@@ -810,10 +810,10 @@ def main():
logger.info("Loading checkpoints saved during training for evaluation")
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
checkpoints = [
os.path.dirname(c)
for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
]
else:
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
@@ -830,7 +830,7 @@ def main():
# Evaluate
result = evaluate(args, model, tokenizer, prefix=global_step)
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
results.update(result)
logger.info("Results: {}".format(results))

View File

@@ -189,7 +189,7 @@ def main():
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
elif isinstance(obj, int):
return obj
return list(tokenize_and_encode(o) for o in obj)
return [tokenize_and_encode(o) for o in obj]
logger.info("Encoding dataset...")
train_dataset = load_rocstories_dataset(args.train_dataset)

View File

@@ -696,9 +696,9 @@ def main():
checkpoints = [args.model_name_or_path]
if args.eval_all_checkpoints:
checkpoints = list(
checkpoints = [
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
]
logger.info("Evaluate the following checkpoints: %s", checkpoints)
@@ -712,7 +712,7 @@ def main():
# Evaluate
result = evaluate(args, model, tokenizer, prefix=global_step)
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
results.update(result)
logger.info("Results: {}".format(results))

View File

@@ -111,7 +111,7 @@ def eval_data_dir(
if num_return_sequences > 1:
preds = chunks(preds, num_return_sequences) # batch size chunks, each of size num_return_seq
for i, pred in enumerate(preds):
results.append(dict(pred=pred, id=ids[i].item()))
results.append({"pred": pred, "id": ids[i].item()})
save_json(results, save_path)
return results, sampler.num_replicas
@@ -232,7 +232,7 @@ def combine_partial_results(partial_results) -> List:
records = []
for partial_result in partial_results:
records.extend(partial_result)
records = list(sorted(records, key=lambda x: x["id"]))
records = sorted(records, key=lambda x: x["id"])
preds = [x["pred"] for x in records]
return preds

View File

@@ -76,7 +76,7 @@ def generate_summaries_or_translations(
fout.close()
runtime = int(time.time() - start_time) # seconds
n_obs = len(examples)
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}
def datetime_now():

View File

@@ -36,7 +36,7 @@ def parse_search_arg(search):
groups = search.split()
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
entry_names = list(entries.keys())
sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
matrix = [list(x) for x in itertools.product(*sets)]
return matrix, entry_names

View File

@@ -456,7 +456,7 @@ def pickle_save(obj, path):
def flatten_list(summary_ids: List[List]):
return [x for x in itertools.chain.from_iterable(summary_ids)]
return list(itertools.chain.from_iterable(summary_ids))
def save_git_info(folder_path: str) -> None: