Use Python 3.9 syntax in tests (#37343)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -4923,8 +4922,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
||||
def get_commit_history(self, repo):
|
||||
commit_logs = subprocess.run(
|
||||
"git log".split(),
|
||||
stderr=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
encoding="utf-8",
|
||||
cwd=repo,
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -148,7 +147,7 @@ if __name__ == "__main__":
|
||||
for dataset_length in [101, 40, 7]:
|
||||
dataset = DummyDataset(dataset_length)
|
||||
|
||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||
def compute_metrics(p: EvalPrediction) -> dict:
|
||||
sequential = list(range(len(dataset)))
|
||||
success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
|
||||
if not success and training_args.local_rank == 0:
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
@@ -166,7 +165,7 @@ if __name__ == "__main__":
|
||||
device = torch.device(torch.distributed.get_rank())
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name).to(device)
|
||||
|
||||
def compute_metrics(p: EvalPrediction) -> Dict[str, bool]:
|
||||
def compute_metrics(p: EvalPrediction) -> dict[str, bool]:
|
||||
return {"accuracy": (p.predictions == p.label_ids).mean()}
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -20,7 +20,6 @@
|
||||
#
|
||||
|
||||
import sys
|
||||
from typing import Dict
|
||||
|
||||
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
|
||||
from transformers.utils import logging
|
||||
@@ -79,7 +78,7 @@ def main():
|
||||
for dataset_length in [1001, 256, 15]:
|
||||
dataset = DummyDataset(dataset_length)
|
||||
|
||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||
def compute_metrics(p: EvalPrediction) -> dict:
|
||||
sequential = list(range(len(dataset)))
|
||||
success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
|
||||
return {"success": success}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
Reference in New Issue
Block a user