Use Python 3.9 syntax in examples (#37279)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -18,9 +17,10 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import albumentations as A
|
||||
import numpy as np
|
||||
@@ -200,7 +200,7 @@ class Evaluator:
|
||||
def reset_metric(self):
|
||||
self.metric.reset()
|
||||
|
||||
def postprocess_target_batch(self, target_batch) -> List[Dict[str, torch.Tensor]]:
|
||||
def postprocess_target_batch(self, target_batch) -> list[dict[str, torch.Tensor]]:
|
||||
"""Collect targets in a form of list of dictionaries with keys "masks", "labels"."""
|
||||
batch_masks = target_batch[0]
|
||||
batch_labels = target_batch[1]
|
||||
@@ -214,13 +214,13 @@ class Evaluator:
|
||||
)
|
||||
return post_processed_targets
|
||||
|
||||
def get_target_sizes(self, post_processed_targets) -> List[List[int]]:
|
||||
def get_target_sizes(self, post_processed_targets) -> list[list[int]]:
|
||||
target_sizes = []
|
||||
for target in post_processed_targets:
|
||||
target_sizes.append(target["masks"].shape[-2:])
|
||||
return target_sizes
|
||||
|
||||
def postprocess_prediction_batch(self, prediction_batch, target_sizes) -> List[Dict[str, torch.Tensor]]:
|
||||
def postprocess_prediction_batch(self, prediction_batch, target_sizes) -> list[dict[str, torch.Tensor]]:
|
||||
"""Collect predictions in a form of list of dictionaries with keys "masks", "labels", "scores"."""
|
||||
|
||||
model_output = ModelOutput(class_queries_logits=prediction_batch[0], masks_queries_logits=prediction_batch[1])
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -21,9 +20,10 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
from typing import Any
|
||||
|
||||
import albumentations as A
|
||||
import datasets
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -551,7 +550,7 @@ def main():
|
||||
covariance_matrix=1e-5 * sigma,
|
||||
)
|
||||
new_token_embeddings = torch.stack(
|
||||
tuple((dist.sample() for _ in range(len(special_tokens)))),
|
||||
tuple(dist.sample() for _ in range(len(special_tokens))),
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
@@ -571,7 +570,7 @@ def main():
|
||||
covariance_matrix=1e-5 * sigma,
|
||||
)
|
||||
new_token_embeddings = torch.stack(
|
||||
tuple((dist.sample() for _ in range(len(special_tokens)))),
|
||||
tuple(dist.sample() for _ in range(len(special_tokens))),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -518,7 +517,7 @@ def main():
|
||||
covariance_matrix=1e-5 * sigma,
|
||||
)
|
||||
new_token_embeddings = torch.stack(
|
||||
tuple((dist.sample() for _ in range(len(special_tokens)))),
|
||||
tuple(dist.sample() for _ in range(len(special_tokens))),
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
@@ -538,7 +537,7 @@ def main():
|
||||
covariance_matrix=1e-5 * sigma,
|
||||
)
|
||||
new_token_embeddings = torch.stack(
|
||||
tuple((dist.sample() for _ in range(len(special_tokens)))),
|
||||
tuple(dist.sample() for _ in range(len(special_tokens))),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -18,9 +17,10 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, List, Mapping, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import albumentations as A
|
||||
import numpy as np
|
||||
@@ -60,7 +60,7 @@ class ModelOutput:
|
||||
|
||||
|
||||
def format_image_annotations_as_coco(
|
||||
image_id: str, categories: List[int], areas: List[float], bboxes: List[Tuple[float]]
|
||||
image_id: str, categories: list[int], areas: list[float], bboxes: list[tuple[float]]
|
||||
) -> dict:
|
||||
"""Format one set of image annotations to the COCO format
|
||||
|
||||
@@ -94,7 +94,7 @@ def format_image_annotations_as_coco(
|
||||
}
|
||||
|
||||
|
||||
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
|
||||
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1]
|
||||
to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates.
|
||||
@@ -148,7 +148,7 @@ def augment_and_transform_batch(
|
||||
return result
|
||||
|
||||
|
||||
def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, List[Any]]]:
|
||||
def collate_fn(batch: list[BatchFeature]) -> Mapping[str, Union[torch.Tensor, list[Any]]]:
|
||||
data = {}
|
||||
data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
|
||||
data["labels"] = [x["labels"] for x in batch]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -19,9 +18,10 @@ import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Mapping, Tuple, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import albumentations as A
|
||||
import datasets
|
||||
@@ -61,7 +61,7 @@ require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/sema
|
||||
|
||||
# Copied from examples/pytorch/object-detection/run_object_detection.format_image_annotations_as_coco
|
||||
def format_image_annotations_as_coco(
|
||||
image_id: str, categories: List[int], areas: List[float], bboxes: List[Tuple[float]]
|
||||
image_id: str, categories: list[int], areas: list[float], bboxes: list[tuple[float]]
|
||||
) -> dict:
|
||||
"""Format one set of image annotations to the COCO format
|
||||
|
||||
@@ -96,7 +96,7 @@ def format_image_annotations_as_coco(
|
||||
|
||||
|
||||
# Copied from examples/pytorch/object-detection/run_object_detection.convert_bbox_yolo_to_pascal
|
||||
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
|
||||
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1]
|
||||
to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates.
|
||||
@@ -152,7 +152,7 @@ def augment_and_transform_batch(
|
||||
|
||||
|
||||
# Copied from examples/pytorch/object-detection/run_object_detection.collate_fn
|
||||
def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, List[Any]]]:
|
||||
def collate_fn(batch: list[BatchFeature]) -> Mapping[str, Union[torch.Tensor, list[Any]]]:
|
||||
data = {}
|
||||
data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
|
||||
data["labels"] = [x["labels"] for x in batch]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 HuggingFace Inc..
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -33,7 +32,7 @@ def get_results(output_dir):
|
||||
results = {}
|
||||
path = os.path.join(output_dir, "all_results.json")
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
with open(path) as f:
|
||||
results = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"can't find {path}")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -22,7 +21,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@@ -469,7 +468,7 @@ def main():
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
) -> tuple[list[str], list[str]]:
|
||||
questions = examples[question_column]
|
||||
contexts = examples[context_column]
|
||||
answers = examples[answer_column]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -18,7 +17,7 @@ A subclass of `Trainer` specific to Question-Answering tasks
|
||||
|
||||
import math
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
@@ -42,10 +41,10 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
self,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
eval_examples=None,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
**gen_kwargs,
|
||||
) -> Dict[str, float]:
|
||||
) -> dict[str, float]:
|
||||
gen_kwargs = gen_kwargs.copy()
|
||||
|
||||
# Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -20,7 +19,7 @@ import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
@@ -32,7 +31,7 @@ logger = logging.getLogger(__name__)
|
||||
def postprocess_qa_predictions(
|
||||
examples,
|
||||
features,
|
||||
predictions: Tuple[np.ndarray, np.ndarray],
|
||||
predictions: tuple[np.ndarray, np.ndarray],
|
||||
version_2_with_negative: bool = False,
|
||||
n_best_size: int = 20,
|
||||
max_answer_length: int = 30,
|
||||
@@ -223,7 +222,7 @@ def postprocess_qa_predictions(
|
||||
# If we have an output_dir, let's save all those dicts.
|
||||
if output_dir is not None:
|
||||
if not os.path.isdir(output_dir):
|
||||
raise EnvironmentError(f"{output_dir} is not a directory.")
|
||||
raise OSError(f"{output_dir} is not a directory.")
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
@@ -253,7 +252,7 @@ def postprocess_qa_predictions(
|
||||
def postprocess_qa_predictions_with_beam_search(
|
||||
examples,
|
||||
features,
|
||||
predictions: Tuple[np.ndarray, np.ndarray],
|
||||
predictions: tuple[np.ndarray, np.ndarray],
|
||||
version_2_with_negative: bool = False,
|
||||
n_best_size: int = 20,
|
||||
max_answer_length: int = 30,
|
||||
@@ -417,7 +416,7 @@ def postprocess_qa_predictions_with_beam_search(
|
||||
# If we have an output_dir, let's save all those dicts.
|
||||
if output_dir is not None:
|
||||
if not os.path.isdir(output_dir):
|
||||
raise EnvironmentError(f"{output_dir} is not a directory.")
|
||||
raise OSError(f"{output_dir} is not a directory.")
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -258,7 +257,7 @@ def main():
|
||||
else:
|
||||
repo_id = data_args.dataset_name
|
||||
filename = "id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset")))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
label2id = {v: str(k) for k, v in id2label.items()}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -316,7 +315,7 @@ def main():
|
||||
else:
|
||||
repo_id = args.dataset_name
|
||||
filename = "id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset")))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -20,7 +19,7 @@ import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
@@ -328,7 +327,7 @@ class DataCollatorForWav2Vec2Pretraining:
|
||||
mask_time_prob: Optional[float] = 0.65
|
||||
mask_time_length: Optional[int] = 10
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
|
||||
# reformat list to dict and set to pytorch format
|
||||
batch = self.feature_extractor.pad(
|
||||
features,
|
||||
@@ -716,7 +715,7 @@ def main():
|
||||
}
|
||||
log_str = ""
|
||||
for k, v in train_logs.items():
|
||||
log_str += "| {}: {:.3e}".format(k, v.item())
|
||||
log_str += f"| {k}: {v.item():.3e}"
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
progress_bar.write(log_str)
|
||||
@@ -773,7 +772,7 @@ def main():
|
||||
|
||||
log_str = ""
|
||||
for k, v in val_logs.items():
|
||||
log_str += "| {}: {:.3e}".format(k, v.item())
|
||||
log_str += f"| {k}: {v.item():.3e}"
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
progress_bar.write(log_str)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -24,7 +23,7 @@ import re
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@@ -211,11 +210,11 @@ class DataTrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
chars_to_ignore: Optional[List[str]] = list_field(
|
||||
chars_to_ignore: Optional[list[str]] = list_field(
|
||||
default=None,
|
||||
metadata={"help": "A list of characters to remove from the transcripts."},
|
||||
)
|
||||
eval_metrics: List[str] = list_field(
|
||||
eval_metrics: list[str] = list_field(
|
||||
default=["wer"],
|
||||
metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
|
||||
)
|
||||
@@ -318,7 +317,7 @@ class DataCollatorCTCWithPadding:
|
||||
pad_to_multiple_of_labels: Optional[int] = None
|
||||
feature_extractor_input_name: Optional[str] = "input_values"
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
|
||||
# split inputs and labels since they have to be of different lengths and need
|
||||
# different padding methods
|
||||
input_features = [
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -24,7 +23,7 @@ import re
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@@ -201,11 +200,11 @@ class DataTrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
chars_to_ignore: Optional[List[str]] = list_field(
|
||||
chars_to_ignore: Optional[list[str]] = list_field(
|
||||
default=None,
|
||||
metadata={"help": "A list of characters to remove from the transcripts."},
|
||||
)
|
||||
eval_metrics: List[str] = list_field(
|
||||
eval_metrics: list[str] = list_field(
|
||||
default=["wer"],
|
||||
metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
|
||||
)
|
||||
@@ -300,7 +299,7 @@ class DataCollatorCTCWithPadding:
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
pad_to_multiple_of_labels: Optional[int] = None
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
|
||||
# split inputs and labels since they have to be of different lengths and need
|
||||
# different padding methods
|
||||
input_features = [{"input_values": feature["input_values"]} for feature in features]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -23,7 +22,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@@ -110,11 +109,11 @@ class ModelArguments:
|
||||
freeze_encoder: bool = field(
|
||||
default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
|
||||
)
|
||||
forced_decoder_ids: List[List[int]] = field(
|
||||
forced_decoder_ids: list[list[int]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
|
||||
)
|
||||
suppress_tokens: List[int] = field(
|
||||
suppress_tokens: list[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -247,7 +246,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
||||
decoder_start_token_id: int
|
||||
forward_attention_mask: bool
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
|
||||
# split inputs and labels since they have to be of different lengths and need
|
||||
# different padding methods
|
||||
model_input_name = self.processor.model_input_names[0]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 HuggingFace Inc..
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -51,7 +50,7 @@ def get_results(output_dir):
|
||||
results = {}
|
||||
path = os.path.join(output_dir, "all_results.json")
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
with open(path) as f:
|
||||
results = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"can't find {path}")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 HuggingFace Inc..
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -87,7 +86,7 @@ def get_results(output_dir):
|
||||
results = {}
|
||||
path = os.path.join(output_dir, "all_results.json")
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
with open(path) as f:
|
||||
results = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"can't find {path}")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -21,7 +20,7 @@ import os
|
||||
import random
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@@ -256,7 +255,7 @@ class ModelArguments:
|
||||
)
|
||||
|
||||
|
||||
def get_label_list(raw_dataset, split="train") -> List[str]:
|
||||
def get_label_list(raw_dataset, split="train") -> list[str]:
|
||||
"""Get the list of labels from a multi-label dataset"""
|
||||
|
||||
if isinstance(raw_dataset[split]["label"][0], list):
|
||||
@@ -537,7 +536,7 @@ def main():
|
||||
model.config.id2label = {id: label for label, id in label_to_id.items()}
|
||||
elif not is_regression: # classification, but not training
|
||||
logger.info("using label infos in the model config")
|
||||
logger.info("label2id: {}".format(model.config.label2id))
|
||||
logger.info(f"label2id: {model.config.label2id}")
|
||||
label_to_id = model.config.label2id
|
||||
else: # regression
|
||||
label_to_id = None
|
||||
@@ -549,7 +548,7 @@ def main():
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
def multi_labels_to_ids(labels: List[str]) -> List[float]:
|
||||
def multi_labels_to_ids(labels: list[str]) -> list[float]:
|
||||
ids = [0.0] * len(label_to_id) # BCELoss requires float as target type
|
||||
for label in labels:
|
||||
ids[label_to_id[label]] = 1.0
|
||||
@@ -735,7 +734,7 @@ def main():
|
||||
else:
|
||||
item = label_list[item]
|
||||
writer.write(f"{index}\t{item}\n")
|
||||
logger.info("Predict results saved at {}".format(output_predict_file))
|
||||
logger.info(f"Predict results saved at {output_predict_file}")
|
||||
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
|
||||
|
||||
if training_args.push_to_hub:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
@@ -19,7 +18,6 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import PartialState
|
||||
@@ -271,8 +269,8 @@ class _ModelFallbackWrapper(GenerationMixin):
|
||||
)
|
||||
|
||||
def _reorder_cache(
|
||||
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
||||
) -> Tuple[Tuple[torch.Tensor]]:
|
||||
self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor
|
||||
) -> tuple[tuple[torch.Tensor]]:
|
||||
"""
|
||||
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
|
||||
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
Reference in New Issue
Block a user