Use Python 3.9 syntax in tests (#37343)

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-08 20:12:08 +08:00
committed by GitHub
parent 0fc683d1cd
commit 1e6b546ea6
666 changed files with 265 additions and 946 deletions

View File

@@ -17,7 +17,6 @@ import inspect
import json
import random
import tempfile
from typing import List, Tuple
import numpy as np
@@ -142,7 +141,7 @@ class FlaxModelTesterMixin:
return inputs_dict
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
diff = np.abs(a - b).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
def test_model_outputs_equivalence(self):
@@ -153,7 +152,7 @@ class FlaxModelTesterMixin:
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
if isinstance(tuple_object, (list, tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
@@ -696,7 +695,7 @@ class FlaxModelTesterMixin:
self.assertEqual(
v.shape,
flat_params[k].shape,
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
f"Shapes of {k} do not match. Expecting {v.shape}, got {flat_params[k].shape}.",
)
# Check that setting params raises an ValueError when _do_init is False
@@ -722,7 +721,7 @@ class FlaxModelTesterMixin:
self.assertEqual(
v.shape,
flat_params[k].shape,
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
f"Shapes of {k} do not match. Expecting {v.shape}, got {flat_params[k].shape}.",
)
for model_class in self.all_model_classes:
@@ -801,7 +800,7 @@ class FlaxModelTesterMixin:
self.assertEqual(len(state_file), 1)
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
with open(index_file, encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())