Use Python 3.9 syntax in tests (#37343)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -17,7 +17,6 @@ import inspect
|
||||
import json
|
||||
import random
|
||||
import tempfile
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -142,7 +141,7 @@ class FlaxModelTesterMixin:
|
||||
return inputs_dict
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
diff = np.abs(a - b).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
@@ -153,7 +152,7 @@ class FlaxModelTesterMixin:
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
if isinstance(tuple_object, (list, tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
@@ -696,7 +695,7 @@ class FlaxModelTesterMixin:
|
||||
self.assertEqual(
|
||||
v.shape,
|
||||
flat_params[k].shape,
|
||||
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
|
||||
f"Shapes of {k} do not match. Expecting {v.shape}, got {flat_params[k].shape}.",
|
||||
)
|
||||
|
||||
# Check that setting params raises an ValueError when _do_init is False
|
||||
@@ -722,7 +721,7 @@ class FlaxModelTesterMixin:
|
||||
self.assertEqual(
|
||||
v.shape,
|
||||
flat_params[k].shape,
|
||||
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
|
||||
f"Shapes of {k} do not match. Expecting {v.shape}, got {flat_params[k].shape}.",
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -801,7 +800,7 @@ class FlaxModelTesterMixin:
|
||||
self.assertEqual(len(state_file), 1)
|
||||
|
||||
# Check the index and the shard files found match
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
with open(index_file, encoding="utf-8") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
|
||||
Reference in New Issue
Block a user