Apply ruff flake8-comprehensions (#21694)
This commit is contained in:
@@ -192,7 +192,7 @@ def main():
|
||||
|
||||
# Optionally, predict on dev set and write to output_dir
|
||||
if args.do_predict:
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
|
||||
checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
|
||||
model = model.load_from_checkpoint(checkpoints[-1])
|
||||
return trainer.test(model)
|
||||
|
||||
|
||||
@@ -211,6 +211,6 @@ if __name__ == "__main__":
|
||||
# pl use this default format to create a checkpoint:
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
|
||||
# /pytorch_lightning/callbacks/model_checkpoint.py#L322
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
|
||||
checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
|
||||
model = model.load_from_checkpoint(checkpoints[-1])
|
||||
trainer.test(model)
|
||||
|
||||
@@ -810,10 +810,10 @@ def main():
|
||||
logger.info("Loading checkpoints saved during training for evaluation")
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
checkpoints = [
|
||||
os.path.dirname(c)
|
||||
for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
]
|
||||
|
||||
else:
|
||||
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
|
||||
@@ -830,7 +830,7 @@ def main():
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
@@ -189,7 +189,7 @@ def main():
|
||||
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
|
||||
elif isinstance(obj, int):
|
||||
return obj
|
||||
return list(tokenize_and_encode(o) for o in obj)
|
||||
return [tokenize_and_encode(o) for o in obj]
|
||||
|
||||
logger.info("Encoding dataset...")
|
||||
train_dataset = load_rocstories_dataset(args.train_dataset)
|
||||
|
||||
@@ -696,9 +696,9 @@ def main():
|
||||
checkpoints = [args.model_name_or_path]
|
||||
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
checkpoints = [
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
]
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
@@ -712,7 +712,7 @@ def main():
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
@@ -111,7 +111,7 @@ def eval_data_dir(
|
||||
if num_return_sequences > 1:
|
||||
preds = chunks(preds, num_return_sequences) # batch size chunks, each of size num_return_seq
|
||||
for i, pred in enumerate(preds):
|
||||
results.append(dict(pred=pred, id=ids[i].item()))
|
||||
results.append({"pred": pred, "id": ids[i].item()})
|
||||
save_json(results, save_path)
|
||||
return results, sampler.num_replicas
|
||||
|
||||
@@ -232,7 +232,7 @@ def combine_partial_results(partial_results) -> List:
|
||||
records = []
|
||||
for partial_result in partial_results:
|
||||
records.extend(partial_result)
|
||||
records = list(sorted(records, key=lambda x: x["id"]))
|
||||
records = sorted(records, key=lambda x: x["id"])
|
||||
preds = [x["pred"] for x in records]
|
||||
return preds
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ def generate_summaries_or_translations(
|
||||
fout.close()
|
||||
runtime = int(time.time() - start_time) # seconds
|
||||
n_obs = len(examples)
|
||||
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
|
||||
return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}
|
||||
|
||||
|
||||
def datetime_now():
|
||||
|
||||
@@ -36,7 +36,7 @@ def parse_search_arg(search):
|
||||
groups = search.split()
|
||||
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
|
||||
entry_names = list(entries.keys())
|
||||
sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
|
||||
sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
|
||||
matrix = [list(x) for x in itertools.product(*sets)]
|
||||
return matrix, entry_names
|
||||
|
||||
|
||||
@@ -456,7 +456,7 @@ def pickle_save(obj, path):
|
||||
|
||||
|
||||
def flatten_list(summary_ids: List[List]):
|
||||
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
||||
return list(itertools.chain.from_iterable(summary_ids))
|
||||
|
||||
|
||||
def save_git_info(folder_path: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user