Use Python 3.9 syntax in examples (#37279)

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-07 19:52:21 +08:00
committed by GitHub
parent 08f36771b3
commit 0fb8d49e88
123 changed files with 358 additions and 451 deletions

View File

@@ -19,9 +19,10 @@ import math
import os
import pickle
import socket
from collections.abc import Iterable
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union
from typing import Callable, Union
import git
import numpy as np
@@ -67,7 +68,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
return loss, nll_loss
def lmap(f: Callable, x: Iterable) -> List:
def lmap(f: Callable, x: Iterable) -> list:
"""list(map(f, x))"""
return list(map(f, x))
@@ -77,11 +78,11 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], dict]:
def non_pad_len(tokens: np.ndarray) -> int:
return np.count_nonzero(tokens != tokenizer.pad_token_id)
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
def decode_pred(pred: EvalPrediction) -> tuple[list[str], list[str]]:
pred_ids = pred.predictions
label_ids = pred.label_ids
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
@@ -91,16 +92,16 @@ def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) ->
label_str = lmap(str.strip, label_str)
return pred_str, label_str
def summarization_metrics(pred: EvalPrediction) -> Dict:
def summarization_metrics(pred: EvalPrediction) -> dict:
pred_str, label_str = decode_pred(pred)
rouge: Dict = calculate_rouge(pred_str, label_str)
rouge: dict = calculate_rouge(pred_str, label_str)
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
rouge.update({"gen_len": summ_len})
return rouge
def translation_metrics(pred: EvalPrediction) -> Dict:
def translation_metrics(pred: EvalPrediction) -> dict:
pred_str, label_str = decode_pred(pred)
bleu: Dict = calculate_bleu(pred_str, label_str)
bleu: dict = calculate_bleu(pred_str, label_str)
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
bleu.update({"gen_len": gen_len})
return bleu
@@ -183,7 +184,7 @@ class AbstractSeq2SeqDataset(Dataset):
return min(self.src_lens[i], self.max_target_length)
# call fairseq cython function
batch_sampler: List[List[int]] = batch_by_size(
batch_sampler: list[list[int]] = batch_by_size(
sorted_indices,
num_tokens_fn=num_tokens_in_example,
max_tokens=max_tokens_per_batch,
@@ -207,7 +208,7 @@ class AbstractSeq2SeqDataset(Dataset):
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
def __getitem__(self, index) -> dict[str, torch.Tensor]:
"""Call tokenizer on src and tgt_lines"""
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
@@ -237,7 +238,7 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
**self.dataset_kwargs,
)
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
def collate_fn(self, batch) -> dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["labels"] for x in batch])
@@ -255,7 +256,7 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
class Seq2SeqDataset(AbstractSeq2SeqDataset):
"""A dataset that calls prepare_seq2seq_batch."""
def __getitem__(self, index) -> Dict[str, str]:
def __getitem__(self, index) -> dict[str, str]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
@@ -263,9 +264,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
assert tgt_line, f"empty tgt line for index {index}"
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
def collate_fn(self, batch) -> dict[str, torch.Tensor]:
"""Call prepare_seq2seq_batch."""
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
batch_encoding: dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
max_length=self.max_source_length,
@@ -293,7 +294,7 @@ class Seq2SeqDataCollator:
if data_args.tgt_lang is not None:
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
def __call__(self, batch) -> Dict[str, torch.Tensor]:
def __call__(self, batch) -> dict[str, torch.Tensor]:
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
batch = self._encode(batch)
input_ids, attention_mask, labels = (
@@ -329,7 +330,7 @@ class Seq2SeqDataCollator:
shifted_input_ids[..., 0] = self.pad_token_id
return shifted_input_ids
def _encode(self, batch) -> Dict[str, torch.Tensor]:
def _encode(self, batch) -> dict[str, torch.Tensor]:
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
@@ -355,7 +356,7 @@ class SortishSampler(Sampler):
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
def sortish_sampler_indices(data: list, bs: int, shuffle=True) -> np.array:
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
if not shuffle:
return np.argsort(np.array(data) * -1)
@@ -455,7 +456,7 @@ def pickle_save(obj, path):
return pickle.dump(obj, f)
def flatten_list(summary_ids: List[List]):
def flatten_list(summary_ids: list[list]):
return list(itertools.chain.from_iterable(summary_ids))
@@ -506,14 +507,14 @@ def extract_rouge_mid_statistics(dct):
def calculate_rouge(
pred_lns: List[str],
tgt_lns: List[str],
pred_lns: list[str],
tgt_lns: list[str],
use_stemmer=True,
rouge_keys=ROUGE_KEYS,
return_precision_and_recall=False,
bootstrap_aggregation=True,
newline_sep=True,
) -> Dict:
) -> dict:
"""Calculate rouge using rouge_scorer package.
Args:
@@ -590,19 +591,19 @@ def any_requires_grad(model: nn.Module) -> bool:
def assert_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
model_grads: list[bool] = list(grad_status(model))
n_require_grad = sum(lmap(int, model_grads))
npars = len(model_grads)
assert not any(model_grads), f"{n_require_grad / npars:.1%} of {npars} weights require grad"
def assert_not_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
model_grads: list[bool] = list(grad_status(model))
npars = len(model_grads)
assert any(model_grads), f"none of {npars} weights require grad"
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
def parse_numeric_n_bool_cl_kwargs(unparsed_args: list[str]) -> dict[str, Union[int, float, bool]]:
"""
Parse an argv list of unspecified command line args to a dict.
Assumes all values are either numeric or boolean in the form of true/false.