Use Python 3.9 syntax in examples (#37279)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -19,9 +19,10 @@ import math
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
from collections.abc import Iterable
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
||||
from typing import Callable, Union
|
||||
|
||||
import git
|
||||
import numpy as np
|
||||
@@ -67,7 +68,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
||||
return loss, nll_loss
|
||||
|
||||
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
def lmap(f: Callable, x: Iterable) -> list:
|
||||
"""list(map(f, x))"""
|
||||
return list(map(f, x))
|
||||
|
||||
@@ -77,11 +78,11 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
|
||||
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
|
||||
|
||||
|
||||
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
|
||||
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], dict]:
|
||||
def non_pad_len(tokens: np.ndarray) -> int:
|
||||
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||
|
||||
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
||||
def decode_pred(pred: EvalPrediction) -> tuple[list[str], list[str]]:
|
||||
pred_ids = pred.predictions
|
||||
label_ids = pred.label_ids
|
||||
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||
@@ -91,16 +92,16 @@ def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) ->
|
||||
label_str = lmap(str.strip, label_str)
|
||||
return pred_str, label_str
|
||||
|
||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||
def summarization_metrics(pred: EvalPrediction) -> dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||
rouge: dict = calculate_rouge(pred_str, label_str)
|
||||
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
rouge.update({"gen_len": summ_len})
|
||||
return rouge
|
||||
|
||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||
def translation_metrics(pred: EvalPrediction) -> dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||
bleu: dict = calculate_bleu(pred_str, label_str)
|
||||
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
bleu.update({"gen_len": gen_len})
|
||||
return bleu
|
||||
@@ -183,7 +184,7 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
return min(self.src_lens[i], self.max_target_length)
|
||||
|
||||
# call fairseq cython function
|
||||
batch_sampler: List[List[int]] = batch_by_size(
|
||||
batch_sampler: list[list[int]] = batch_by_size(
|
||||
sorted_indices,
|
||||
num_tokens_fn=num_tokens_in_example,
|
||||
max_tokens=max_tokens_per_batch,
|
||||
@@ -207,7 +208,7 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
|
||||
|
||||
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||
def __getitem__(self, index) -> dict[str, torch.Tensor]:
|
||||
"""Call tokenizer on src and tgt_lines"""
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
@@ -237,7 +238,7 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
def collate_fn(self, batch) -> dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["labels"] for x in batch])
|
||||
@@ -255,7 +256,7 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
"""A dataset that calls prepare_seq2seq_batch."""
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, str]:
|
||||
def __getitem__(self, index) -> dict[str, str]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
@@ -263,9 +264,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
def collate_fn(self, batch) -> dict[str, torch.Tensor]:
|
||||
"""Call prepare_seq2seq_batch."""
|
||||
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
|
||||
batch_encoding: dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
max_length=self.max_source_length,
|
||||
@@ -293,7 +294,7 @@ class Seq2SeqDataCollator:
|
||||
if data_args.tgt_lang is not None:
|
||||
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
|
||||
|
||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, batch) -> dict[str, torch.Tensor]:
|
||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||
batch = self._encode(batch)
|
||||
input_ids, attention_mask, labels = (
|
||||
@@ -329,7 +330,7 @@ class Seq2SeqDataCollator:
|
||||
shifted_input_ids[..., 0] = self.pad_token_id
|
||||
return shifted_input_ids
|
||||
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
def _encode(self, batch) -> dict[str, torch.Tensor]:
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
@@ -355,7 +356,7 @@ class SortishSampler(Sampler):
|
||||
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
|
||||
|
||||
|
||||
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
|
||||
def sortish_sampler_indices(data: list, bs: int, shuffle=True) -> np.array:
|
||||
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||
if not shuffle:
|
||||
return np.argsort(np.array(data) * -1)
|
||||
@@ -455,7 +456,7 @@ def pickle_save(obj, path):
|
||||
return pickle.dump(obj, f)
|
||||
|
||||
|
||||
def flatten_list(summary_ids: List[List]):
|
||||
def flatten_list(summary_ids: list[list]):
|
||||
return list(itertools.chain.from_iterable(summary_ids))
|
||||
|
||||
|
||||
@@ -506,14 +507,14 @@ def extract_rouge_mid_statistics(dct):
|
||||
|
||||
|
||||
def calculate_rouge(
|
||||
pred_lns: List[str],
|
||||
tgt_lns: List[str],
|
||||
pred_lns: list[str],
|
||||
tgt_lns: list[str],
|
||||
use_stemmer=True,
|
||||
rouge_keys=ROUGE_KEYS,
|
||||
return_precision_and_recall=False,
|
||||
bootstrap_aggregation=True,
|
||||
newline_sep=True,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
"""Calculate rouge using rouge_scorer package.
|
||||
|
||||
Args:
|
||||
@@ -590,19 +591,19 @@ def any_requires_grad(model: nn.Module) -> bool:
|
||||
|
||||
|
||||
def assert_all_frozen(model):
|
||||
model_grads: List[bool] = list(grad_status(model))
|
||||
model_grads: list[bool] = list(grad_status(model))
|
||||
n_require_grad = sum(lmap(int, model_grads))
|
||||
npars = len(model_grads)
|
||||
assert not any(model_grads), f"{n_require_grad / npars:.1%} of {npars} weights require grad"
|
||||
|
||||
|
||||
def assert_not_all_frozen(model):
|
||||
model_grads: List[bool] = list(grad_status(model))
|
||||
model_grads: list[bool] = list(grad_status(model))
|
||||
npars = len(model_grads)
|
||||
assert any(model_grads), f"none of {npars} weights require grad"
|
||||
|
||||
|
||||
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
|
||||
def parse_numeric_n_bool_cl_kwargs(unparsed_args: list[str]) -> dict[str, Union[int, float, bool]]:
|
||||
"""
|
||||
Parse an argv list of unspecified command line args to a dict.
|
||||
Assumes all values are either numeric or boolean in the form of true/false.
|
||||
|
||||
Reference in New Issue
Block a user