[Flax tests] fix test_model_outputs_equivalence (#15571)
* fix test_model_outputs_equivalence * fix tuple outputs for blenderbot
This commit is contained in:
@@ -719,7 +719,7 @@ class FlaxBlenderbotEncoder(nn.Module):
|
|||||||
last_hidden_states = self.layer_norm(last_hidden_states)
|
last_hidden_states = self.layer_norm(last_hidden_states)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return outputs
|
return (last_hidden_states,) + outputs[1:]
|
||||||
|
|
||||||
return FlaxBaseModelOutput(
|
return FlaxBaseModelOutput(
|
||||||
last_hidden_state=last_hidden_states,
|
last_hidden_state=last_hidden_states,
|
||||||
@@ -797,7 +797,7 @@ class FlaxBlenderbotDecoder(nn.Module):
|
|||||||
last_hidden_states = self.layer_norm(last_hidden_states)
|
last_hidden_states = self.layer_norm(last_hidden_states)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return outputs
|
return (last_hidden_states,) + outputs[1:]
|
||||||
|
|
||||||
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=last_hidden_states,
|
last_hidden_state=last_hidden_states,
|
||||||
|
|||||||
@@ -134,10 +134,6 @@ class FlaxModelTesterMixin:
|
|||||||
def test_model_outputs_equivalence(self):
|
def test_model_outputs_equivalence(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
def set_nan_tensor_to_zero(t):
|
|
||||||
t[t != t] = 0
|
|
||||||
return t
|
|
||||||
|
|
||||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||||
@@ -149,11 +145,9 @@ class FlaxModelTesterMixin:
|
|||||||
elif tuple_object is None:
|
elif tuple_object is None:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.assert_almost_equals(
|
self.assert_almost_equals(jnp.nan_to_num(tuple_object), jnp.nan_to_num(dict_object), 1e-5)
|
||||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
|
|
||||||
)
|
|
||||||
|
|
||||||
recursive_check(tuple_output, dict_output)
|
recursive_check(tuple_output, dict_output)
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user