[Flax tests] fix test_model_outputs_equivalence (#15571)
* fix test_model_outputs_equivalence * fix tuple outputs for blenderbot
This commit is contained in:
@@ -134,10 +134,6 @@ class FlaxModelTesterMixin:
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def set_nan_tensor_to_zero(t):
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
@@ -149,11 +145,9 @@ class FlaxModelTesterMixin:
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assert_almost_equals(
|
||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
|
||||
)
|
||||
self.assert_almost_equals(jnp.nan_to_num(tuple_object), jnp.nan_to_num(dict_object), 1e-5)
|
||||
|
||||
recursive_check(tuple_output, dict_output)
|
||||
recursive_check(tuple_output, dict_output)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
Reference in New Issue
Block a user