Use Python 3.9 syntax in examples (#37279)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -19,7 +19,6 @@ import time
|
||||
from json import JSONDecodeError
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
@@ -55,10 +54,10 @@ def eval_data_dir(
|
||||
task="summarization",
|
||||
local_rank=None,
|
||||
num_return_sequences=1,
|
||||
dataset_kwargs: Dict = None,
|
||||
dataset_kwargs: dict = None,
|
||||
prefix="",
|
||||
**generate_kwargs,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
|
||||
model_name = str(model_name)
|
||||
assert local_rank is not None
|
||||
@@ -211,7 +210,7 @@ def run_generate():
|
||||
calc_bleu = "translation" in args.task
|
||||
score_fn = calculate_bleu if calc_bleu else calculate_rouge
|
||||
metric_name = "bleu" if calc_bleu else "rouge"
|
||||
metrics: Dict = score_fn(preds, labels)
|
||||
metrics: dict = score_fn(preds, labels)
|
||||
metrics["n_obs"] = len(preds)
|
||||
runtime = time.time() - start_time
|
||||
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 4)
|
||||
@@ -227,7 +226,7 @@ def run_generate():
|
||||
shutil.rmtree(json_save_dir)
|
||||
|
||||
|
||||
def combine_partial_results(partial_results) -> List:
|
||||
def combine_partial_results(partial_results) -> list:
|
||||
"""Concatenate partial results into one file, then sort it by id."""
|
||||
records = []
|
||||
for partial_result in partial_results:
|
||||
@@ -237,7 +236,7 @@ def combine_partial_results(partial_results) -> List:
|
||||
return preds
|
||||
|
||||
|
||||
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
|
||||
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> list[dict[str, list]]:
|
||||
# WAIT FOR lots of .json files
|
||||
start_wait = time.time()
|
||||
logger.info("waiting for all nodes to finish")
|
||||
|
||||
Reference in New Issue
Block a user