Use Python 3.9 syntax in tests (#37343)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -25,7 +24,6 @@ import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
@@ -2333,10 +2331,10 @@ class ModelTesterMixin:
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
if isinstance(tuple_object, (list, tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
elif isinstance(tuple_object, dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(
|
||||
tuple_object.values(), dict_object.values()
|
||||
):
|
||||
@@ -2455,7 +2453,7 @@ class ModelTesterMixin:
|
||||
return new_tf_outputs, new_pt_outputs
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
diff = np.abs(a - b).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
@@ -2656,7 +2654,7 @@ class ModelTesterMixin:
|
||||
for value, parallel_value in zip(output, parallel_output):
|
||||
if isinstance(value, torch.Tensor):
|
||||
torch.testing.assert_close(value, parallel_value.to("cpu"), rtol=1e-7, atol=1e-7)
|
||||
elif isinstance(value, (Tuple, List)):
|
||||
elif isinstance(value, (tuple, list)):
|
||||
for value_, parallel_value_ in zip(value, parallel_value):
|
||||
torch.testing.assert_close(value_, parallel_value_.to("cpu"), rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user