Use Python 3.9 syntax in examples (#37279)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Huggingface
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import json
|
||||
import unittest
|
||||
|
||||
@@ -25,7 +23,7 @@ from utils import calculate_bleu
|
||||
|
||||
|
||||
filename = get_tests_dir() + "/test_data/fsmt/fsmt_val_data.json"
|
||||
with io.open(filename, "r", encoding="utf-8") as f:
|
||||
with open(filename, encoding="utf-8") as f:
|
||||
bleu_data = json.load(f)
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import time
|
||||
from json import JSONDecodeError
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
@@ -55,10 +54,10 @@ def eval_data_dir(
|
||||
task="summarization",
|
||||
local_rank=None,
|
||||
num_return_sequences=1,
|
||||
dataset_kwargs: Dict = None,
|
||||
dataset_kwargs: dict = None,
|
||||
prefix="",
|
||||
**generate_kwargs,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
|
||||
model_name = str(model_name)
|
||||
assert local_rank is not None
|
||||
@@ -211,7 +210,7 @@ def run_generate():
|
||||
calc_bleu = "translation" in args.task
|
||||
score_fn = calculate_bleu if calc_bleu else calculate_rouge
|
||||
metric_name = "bleu" if calc_bleu else "rouge"
|
||||
metrics: Dict = score_fn(preds, labels)
|
||||
metrics: dict = score_fn(preds, labels)
|
||||
metrics["n_obs"] = len(preds)
|
||||
runtime = time.time() - start_time
|
||||
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 4)
|
||||
@@ -227,7 +226,7 @@ def run_generate():
|
||||
shutil.rmtree(json_save_dir)
|
||||
|
||||
|
||||
def combine_partial_results(partial_results) -> List:
|
||||
def combine_partial_results(partial_results) -> list:
|
||||
"""Concatenate partial results into one file, then sort it by id."""
|
||||
records = []
|
||||
for partial_result in partial_results:
|
||||
@@ -237,7 +236,7 @@ def combine_partial_results(partial_results) -> List:
|
||||
return preds
|
||||
|
||||
|
||||
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
|
||||
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> list[dict[str, list]]:
|
||||
# WAIT FOR lots of .json files
|
||||
start_wait = time.time()
|
||||
logger.info("waiting for all nodes to finish")
|
||||
|
||||
@@ -20,7 +20,6 @@ import time
|
||||
import warnings
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
@@ -36,7 +35,7 @@ DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def generate_summaries_or_translations(
|
||||
examples: List[str],
|
||||
examples: list[str],
|
||||
out_file: str,
|
||||
model_name: str,
|
||||
batch_size: int = 8,
|
||||
@@ -45,7 +44,7 @@ def generate_summaries_or_translations(
|
||||
task="summarization",
|
||||
prefix=None,
|
||||
**generate_kwargs,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
"""Save model.generate results to <out_file>, and return how long it took."""
|
||||
fout = Path(out_file).open("w", encoding="utf-8")
|
||||
model_name = str(model_name)
|
||||
|
||||
@@ -34,7 +34,7 @@ task_score_names = {
|
||||
|
||||
def parse_search_arg(search):
|
||||
groups = search.split()
|
||||
entries = dict((g.split("=") for g in groups))
|
||||
entries = dict(g.split("=") for g in groups)
|
||||
entry_names = list(entries.keys())
|
||||
sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
|
||||
matrix = [list(x) for x in itertools.product(*sets)]
|
||||
@@ -105,7 +105,7 @@ def run_search():
|
||||
col_widths = {col: len(str(col)) for col in col_names}
|
||||
results = []
|
||||
for r in matrix:
|
||||
hparams = dict((x.replace("--", "").split() for x in r))
|
||||
hparams = dict(x.replace("--", "").split() for x in r)
|
||||
args_exp = " ".join(r).split()
|
||||
args_exp.extend(["--bs", str(args.bs)]) # in case we need to reduce its size due to CUDA OOM
|
||||
sys.argv = args_normal + args_exp
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -172,10 +172,10 @@ class Seq2SeqTrainer(Trainer):
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import io
|
||||
import json
|
||||
import subprocess
|
||||
|
||||
@@ -29,5 +28,5 @@ def get_all_data(pairs, n_objs):
|
||||
|
||||
text = get_all_data(pairs, n_objs)
|
||||
filename = "./fsmt_val_data.json"
|
||||
with io.open(filename, "w", encoding="utf-8") as f:
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
bleu_data = json.dump(text, f, indent=2, ensure_ascii=False)
|
||||
|
||||
@@ -19,9 +19,10 @@ import math
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
from collections.abc import Iterable
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
||||
from typing import Callable, Union
|
||||
|
||||
import git
|
||||
import numpy as np
|
||||
@@ -67,7 +68,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
||||
return loss, nll_loss
|
||||
|
||||
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
def lmap(f: Callable, x: Iterable) -> list:
|
||||
"""list(map(f, x))"""
|
||||
return list(map(f, x))
|
||||
|
||||
@@ -77,11 +78,11 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
|
||||
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
|
||||
|
||||
|
||||
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
|
||||
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], dict]:
|
||||
def non_pad_len(tokens: np.ndarray) -> int:
|
||||
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||
|
||||
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
||||
def decode_pred(pred: EvalPrediction) -> tuple[list[str], list[str]]:
|
||||
pred_ids = pred.predictions
|
||||
label_ids = pred.label_ids
|
||||
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||
@@ -91,16 +92,16 @@ def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) ->
|
||||
label_str = lmap(str.strip, label_str)
|
||||
return pred_str, label_str
|
||||
|
||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||
def summarization_metrics(pred: EvalPrediction) -> dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||
rouge: dict = calculate_rouge(pred_str, label_str)
|
||||
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
rouge.update({"gen_len": summ_len})
|
||||
return rouge
|
||||
|
||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||
def translation_metrics(pred: EvalPrediction) -> dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||
bleu: dict = calculate_bleu(pred_str, label_str)
|
||||
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
bleu.update({"gen_len": gen_len})
|
||||
return bleu
|
||||
@@ -183,7 +184,7 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
return min(self.src_lens[i], self.max_target_length)
|
||||
|
||||
# call fairseq cython function
|
||||
batch_sampler: List[List[int]] = batch_by_size(
|
||||
batch_sampler: list[list[int]] = batch_by_size(
|
||||
sorted_indices,
|
||||
num_tokens_fn=num_tokens_in_example,
|
||||
max_tokens=max_tokens_per_batch,
|
||||
@@ -207,7 +208,7 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
|
||||
|
||||
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||
def __getitem__(self, index) -> dict[str, torch.Tensor]:
|
||||
"""Call tokenizer on src and tgt_lines"""
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
@@ -237,7 +238,7 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
def collate_fn(self, batch) -> dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["labels"] for x in batch])
|
||||
@@ -255,7 +256,7 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
"""A dataset that calls prepare_seq2seq_batch."""
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, str]:
|
||||
def __getitem__(self, index) -> dict[str, str]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
@@ -263,9 +264,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
def collate_fn(self, batch) -> dict[str, torch.Tensor]:
|
||||
"""Call prepare_seq2seq_batch."""
|
||||
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
|
||||
batch_encoding: dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
max_length=self.max_source_length,
|
||||
@@ -293,7 +294,7 @@ class Seq2SeqDataCollator:
|
||||
if data_args.tgt_lang is not None:
|
||||
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
|
||||
|
||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, batch) -> dict[str, torch.Tensor]:
|
||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||
batch = self._encode(batch)
|
||||
input_ids, attention_mask, labels = (
|
||||
@@ -329,7 +330,7 @@ class Seq2SeqDataCollator:
|
||||
shifted_input_ids[..., 0] = self.pad_token_id
|
||||
return shifted_input_ids
|
||||
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
def _encode(self, batch) -> dict[str, torch.Tensor]:
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
@@ -355,7 +356,7 @@ class SortishSampler(Sampler):
|
||||
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
|
||||
|
||||
|
||||
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
|
||||
def sortish_sampler_indices(data: list, bs: int, shuffle=True) -> np.array:
|
||||
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||
if not shuffle:
|
||||
return np.argsort(np.array(data) * -1)
|
||||
@@ -455,7 +456,7 @@ def pickle_save(obj, path):
|
||||
return pickle.dump(obj, f)
|
||||
|
||||
|
||||
def flatten_list(summary_ids: List[List]):
|
||||
def flatten_list(summary_ids: list[list]):
|
||||
return list(itertools.chain.from_iterable(summary_ids))
|
||||
|
||||
|
||||
@@ -506,14 +507,14 @@ def extract_rouge_mid_statistics(dct):
|
||||
|
||||
|
||||
def calculate_rouge(
|
||||
pred_lns: List[str],
|
||||
tgt_lns: List[str],
|
||||
pred_lns: list[str],
|
||||
tgt_lns: list[str],
|
||||
use_stemmer=True,
|
||||
rouge_keys=ROUGE_KEYS,
|
||||
return_precision_and_recall=False,
|
||||
bootstrap_aggregation=True,
|
||||
newline_sep=True,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
"""Calculate rouge using rouge_scorer package.
|
||||
|
||||
Args:
|
||||
@@ -590,19 +591,19 @@ def any_requires_grad(model: nn.Module) -> bool:
|
||||
|
||||
|
||||
def assert_all_frozen(model):
|
||||
model_grads: List[bool] = list(grad_status(model))
|
||||
model_grads: list[bool] = list(grad_status(model))
|
||||
n_require_grad = sum(lmap(int, model_grads))
|
||||
npars = len(model_grads)
|
||||
assert not any(model_grads), f"{n_require_grad / npars:.1%} of {npars} weights require grad"
|
||||
|
||||
|
||||
def assert_not_all_frozen(model):
|
||||
model_grads: List[bool] = list(grad_status(model))
|
||||
model_grads: list[bool] = list(grad_status(model))
|
||||
npars = len(model_grads)
|
||||
assert any(model_grads), f"none of {npars} weights require grad"
|
||||
|
||||
|
||||
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
|
||||
def parse_numeric_n_bool_cl_kwargs(unparsed_args: list[str]) -> dict[str, Union[int, float, bool]]:
|
||||
"""
|
||||
Parse an argv list of unspecified command line args to a dict.
|
||||
Assumes all values are either numeric or boolean in the form of true/false.
|
||||
|
||||
Reference in New Issue
Block a user