Use Python 3.9 syntax in examples (#37279)

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-07 19:52:21 +08:00
committed by GitHub
parent 08f36771b3
commit 0fb8d49e88
123 changed files with 358 additions and 451 deletions

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,9 +17,10 @@
import logging
import os
import sys
from collections.abc import Mapping
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Optional
import albumentations as A
import numpy as np
@@ -200,7 +200,7 @@ class Evaluator:
def reset_metric(self):
self.metric.reset()
def postprocess_target_batch(self, target_batch) -> List[Dict[str, torch.Tensor]]:
def postprocess_target_batch(self, target_batch) -> list[dict[str, torch.Tensor]]:
"""Collect targets in a form of list of dictionaries with keys "masks", "labels"."""
batch_masks = target_batch[0]
batch_labels = target_batch[1]
@@ -214,13 +214,13 @@ class Evaluator:
)
return post_processed_targets
def get_target_sizes(self, post_processed_targets) -> List[List[int]]:
def get_target_sizes(self, post_processed_targets) -> list[list[int]]:
target_sizes = []
for target in post_processed_targets:
target_sizes.append(target["masks"].shape[-2:])
return target_sizes
def postprocess_prediction_batch(self, prediction_batch, target_sizes) -> List[Dict[str, torch.Tensor]]:
def postprocess_prediction_batch(self, prediction_batch, target_sizes) -> list[dict[str, torch.Tensor]]:
"""Collect predictions in a form of list of dictionaries with keys "masks", "labels", "scores"."""
model_output = ModelOutput(class_queries_logits=prediction_batch[0], masks_queries_logits=prediction_batch[1])

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,9 +20,10 @@ import logging
import math
import os
import sys
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import Any, Mapping
from typing import Any
import albumentations as A
import datasets

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -551,7 +550,7 @@ def main():
covariance_matrix=1e-5 * sigma,
)
new_token_embeddings = torch.stack(
tuple((dist.sample() for _ in range(len(special_tokens)))),
tuple(dist.sample() for _ in range(len(special_tokens))),
dim=0,
)
else:
@@ -571,7 +570,7 @@ def main():
covariance_matrix=1e-5 * sigma,
)
new_token_embeddings = torch.stack(
tuple((dist.sample() for _ in range(len(special_tokens)))),
tuple(dist.sample() for _ in range(len(special_tokens))),
dim=0,
)

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -518,7 +517,7 @@ def main():
covariance_matrix=1e-5 * sigma,
)
new_token_embeddings = torch.stack(
tuple((dist.sample() for _ in range(len(special_tokens)))),
tuple(dist.sample() for _ in range(len(special_tokens))),
dim=0,
)
else:
@@ -538,7 +537,7 @@ def main():
covariance_matrix=1e-5 * sigma,
)
new_token_embeddings = torch.stack(
tuple((dist.sample() for _ in range(len(special_tokens)))),
tuple(dist.sample() for _ in range(len(special_tokens))),
dim=0,
)

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,9 +17,10 @@
import logging
import os
import sys
from collections.abc import Mapping
from dataclasses import dataclass, field
from functools import partial
from typing import Any, List, Mapping, Optional, Tuple, Union
from typing import Any, Optional, Union
import albumentations as A
import numpy as np
@@ -60,7 +60,7 @@ class ModelOutput:
def format_image_annotations_as_coco(
image_id: str, categories: List[int], areas: List[float], bboxes: List[Tuple[float]]
image_id: str, categories: list[int], areas: list[float], bboxes: list[tuple[float]]
) -> dict:
"""Format one set of image annotations to the COCO format
@@ -94,7 +94,7 @@ def format_image_annotations_as_coco(
}
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
"""
Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1]
to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates.
@@ -148,7 +148,7 @@ def augment_and_transform_batch(
return result
def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, List[Any]]]:
def collate_fn(batch: list[BatchFeature]) -> Mapping[str, Union[torch.Tensor, list[Any]]]:
data = {}
data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
data["labels"] = [x["labels"] for x in batch]

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,9 +18,10 @@ import json
import logging
import math
import os
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import Any, List, Mapping, Tuple, Union
from typing import Any, Union
import albumentations as A
import datasets
@@ -61,7 +61,7 @@ require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/sema
# Copied from examples/pytorch/object-detection/run_object_detection.format_image_annotations_as_coco
def format_image_annotations_as_coco(
image_id: str, categories: List[int], areas: List[float], bboxes: List[Tuple[float]]
image_id: str, categories: list[int], areas: list[float], bboxes: list[tuple[float]]
) -> dict:
"""Format one set of image annotations to the COCO format
@@ -96,7 +96,7 @@ def format_image_annotations_as_coco(
# Copied from examples/pytorch/object-detection/run_object_detection.convert_bbox_yolo_to_pascal
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
"""
Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1]
to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates.
@@ -152,7 +152,7 @@ def augment_and_transform_batch(
# Copied from examples/pytorch/object-detection/run_object_detection.collate_fn
def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, List[Any]]]:
def collate_fn(batch: list[BatchFeature]) -> Mapping[str, Union[torch.Tensor, list[Any]]]:
data = {}
data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
data["labels"] = [x["labels"] for x in batch]

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -33,7 +32,7 @@ def get_results(output_dir):
results = {}
path = os.path.join(output_dir, "all_results.json")
if os.path.exists(path):
with open(path, "r") as f:
with open(path) as f:
results = json.load(f)
else:
raise ValueError(f"can't find {path}")

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,7 +21,7 @@ import logging
import os
import sys
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
from typing import Optional
import datasets
import evaluate
@@ -469,7 +468,7 @@ def main():
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
) -> tuple[list[str], list[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +17,7 @@ A subclass of `Trainer` specific to Question-Answering tasks
import math
import time
from typing import Dict, List, Optional
from typing import Optional
from torch.utils.data import Dataset
@@ -42,10 +41,10 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
self,
eval_dataset: Optional[Dataset] = None,
eval_examples=None,
ignore_keys: Optional[List[str]] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
**gen_kwargs,
) -> Dict[str, float]:
) -> dict[str, float]:
gen_kwargs = gen_kwargs.copy()
# Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,7 +19,7 @@ import collections
import json
import logging
import os
from typing import Optional, Tuple
from typing import Optional
import numpy as np
from tqdm.auto import tqdm
@@ -32,7 +31,7 @@ logger = logging.getLogger(__name__)
def postprocess_qa_predictions(
examples,
features,
predictions: Tuple[np.ndarray, np.ndarray],
predictions: tuple[np.ndarray, np.ndarray],
version_2_with_negative: bool = False,
n_best_size: int = 20,
max_answer_length: int = 30,
@@ -223,7 +222,7 @@ def postprocess_qa_predictions(
# If we have an output_dir, let's save all those dicts.
if output_dir is not None:
if not os.path.isdir(output_dir):
raise EnvironmentError(f"{output_dir} is not a directory.")
raise OSError(f"{output_dir} is not a directory.")
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
@@ -253,7 +252,7 @@ def postprocess_qa_predictions(
def postprocess_qa_predictions_with_beam_search(
examples,
features,
predictions: Tuple[np.ndarray, np.ndarray],
predictions: tuple[np.ndarray, np.ndarray],
version_2_with_negative: bool = False,
n_best_size: int = 20,
max_answer_length: int = 30,
@@ -417,7 +416,7 @@ def postprocess_qa_predictions_with_beam_search(
# If we have an output_dir, let's save all those dicts.
if output_dir is not None:
if not os.path.isdir(output_dir):
raise EnvironmentError(f"{output_dir} is not a directory.")
raise OSError(f"{output_dir} is not a directory.")
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -258,7 +257,7 @@ def main():
else:
repo_id = data_args.dataset_name
filename = "id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset")))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: str(k) for k, v in id2label.items()}

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -316,7 +315,7 @@ def main():
else:
repo_id = args.dataset_name
filename = "id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset")))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,7 +19,7 @@ import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import datasets
import torch
@@ -328,7 +327,7 @@ class DataCollatorForWav2Vec2Pretraining:
mask_time_prob: Optional[float] = 0.65
mask_time_length: Optional[int] = 10
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
# reformat list to dict and set to pytorch format
batch = self.feature_extractor.pad(
features,
@@ -716,7 +715,7 @@ def main():
}
log_str = ""
for k, v in train_logs.items():
log_str += "| {}: {:.3e}".format(k, v.item())
log_str += f"| {k}: {v.item():.3e}"
if accelerator.is_local_main_process:
progress_bar.write(log_str)
@@ -773,7 +772,7 @@ def main():
log_str = ""
for k, v in val_logs.items():
log_str += "| {}: {:.3e}".format(k, v.item())
log_str += f"| {k}: {v.item():.3e}"
if accelerator.is_local_main_process:
progress_bar.write(log_str)

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +23,7 @@ import re
import sys
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import datasets
import evaluate
@@ -211,11 +210,11 @@ class DataTrainingArguments:
)
},
)
chars_to_ignore: Optional[List[str]] = list_field(
chars_to_ignore: Optional[list[str]] = list_field(
default=None,
metadata={"help": "A list of characters to remove from the transcripts."},
)
eval_metrics: List[str] = list_field(
eval_metrics: list[str] = list_field(
default=["wer"],
metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
)
@@ -318,7 +317,7 @@ class DataCollatorCTCWithPadding:
pad_to_multiple_of_labels: Optional[int] = None
feature_extractor_input_name: Optional[str] = "input_values"
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
input_features = [

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +23,7 @@ import re
import sys
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import datasets
import evaluate
@@ -201,11 +200,11 @@ class DataTrainingArguments:
)
},
)
chars_to_ignore: Optional[List[str]] = list_field(
chars_to_ignore: Optional[list[str]] = list_field(
default=None,
metadata={"help": "A list of characters to remove from the transcripts."},
)
eval_metrics: List[str] = list_field(
eval_metrics: list[str] = list_field(
default=["wer"],
metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
)
@@ -300,7 +299,7 @@ class DataCollatorCTCWithPadding:
pad_to_multiple_of: Optional[int] = None
pad_to_multiple_of_labels: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
input_features = [{"input_values": feature["input_values"]} for feature in features]

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,7 +22,7 @@ import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
import datasets
import evaluate
@@ -110,11 +109,11 @@ class ModelArguments:
freeze_encoder: bool = field(
default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
)
forced_decoder_ids: List[List[int]] = field(
forced_decoder_ids: list[list[int]] = field(
default=None,
metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
)
suppress_tokens: List[int] = field(
suppress_tokens: list[int] = field(
default=None,
metadata={
"help": (
@@ -247,7 +246,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
decoder_start_token_id: int
forward_attention_mask: bool
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
model_input_name = self.processor.model_input_names[0]

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -51,7 +50,7 @@ def get_results(output_dir):
results = {}
path = os.path.join(output_dir, "all_results.json")
if os.path.exists(path):
with open(path, "r") as f:
with open(path) as f:
results = json.load(f)
else:
raise ValueError(f"can't find {path}")

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -87,7 +86,7 @@ def get_results(output_dir):
results = {}
path = os.path.join(output_dir, "all_results.json")
if os.path.exists(path):
with open(path, "r") as f:
with open(path) as f:
results = json.load(f)
else:
raise ValueError(f"can't find {path}")

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,7 +20,7 @@ import os
import random
import sys
from dataclasses import dataclass, field
from typing import List, Optional
from typing import Optional
import datasets
import evaluate
@@ -256,7 +255,7 @@ class ModelArguments:
)
def get_label_list(raw_dataset, split="train") -> List[str]:
def get_label_list(raw_dataset, split="train") -> list[str]:
"""Get the list of labels from a multi-label dataset"""
if isinstance(raw_dataset[split]["label"][0], list):
@@ -537,7 +536,7 @@ def main():
model.config.id2label = {id: label for label, id in label_to_id.items()}
elif not is_regression: # classification, but not training
logger.info("using label infos in the model config")
logger.info("label2id: {}".format(model.config.label2id))
logger.info(f"label2id: {model.config.label2id}")
label_to_id = model.config.label2id
else: # regression
label_to_id = None
@@ -549,7 +548,7 @@ def main():
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
def multi_labels_to_ids(labels: List[str]) -> List[float]:
def multi_labels_to_ids(labels: list[str]) -> list[float]:
ids = [0.0] * len(label_to_id) # BCELoss requires float as target type
for label in labels:
ids[label_to_id[label]] = 1.0
@@ -735,7 +734,7 @@ def main():
else:
item = label_list[item]
writer.write(f"{index}\t{item}\n")
logger.info("Predict results saved at {}".format(output_predict_file))
logger.info(f"Predict results saved at {output_predict_file}")
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
if training_args.push_to_hub:

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
@@ -19,7 +18,6 @@
import argparse
import inspect
import logging
from typing import Tuple
import torch
from accelerate import PartialState
@@ -271,8 +269,8 @@ class _ModelFallbackWrapper(GenerationMixin):
)
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> tuple[tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");