From a6885db91224fc26042b0a2fd6c7b6a827c0f5a6 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Wed, 9 Feb 2022 12:26:48 +0100
Subject: [PATCH] [Flax tests] fix test_model_outputs_equivalence (#15571)

* fix test_model_outputs_equivalence

* fix tuple outputs for blenderbot
---
 .../models/blenderbot/modeling_flax_blenderbot.py |  4 ++--
 tests/test_modeling_flax_common.py                | 10 ++--------
 2 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
index 98693befa9..a6508ac274 100644
--- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -719,7 +719,7 @@ class FlaxBlenderbotEncoder(nn.Module):
         last_hidden_states = self.layer_norm(last_hidden_states)
 
         if not return_dict:
-            return outputs
+            return (last_hidden_states,) + outputs[1:]
 
         return FlaxBaseModelOutput(
             last_hidden_state=last_hidden_states,
@@ -797,7 +797,7 @@ class FlaxBlenderbotDecoder(nn.Module):
         last_hidden_states = self.layer_norm(last_hidden_states)
 
         if not return_dict:
-            return outputs
+            return (last_hidden_states,) + outputs[1:]
 
         return FlaxBaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=last_hidden_states,
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index b1d15b6673..1edd41aab0 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -134,10 +134,6 @@ class FlaxModelTesterMixin:
     def test_model_outputs_equivalence(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
-        def set_nan_tensor_to_zero(t):
-            t[t != t] = 0
-            return t
-
         def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
             tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
             dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
@@ -149,11 +145,9 @@ class FlaxModelTesterMixin:
                 elif tuple_object is None:
                     return
                 else:
-                    self.assert_almost_equals(
-                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
-                    )
+                    self.assert_almost_equals(jnp.nan_to_num(tuple_object), jnp.nan_to_num(dict_object), 1e-5)
 
-                recursive_check(tuple_output, dict_output)
+            recursive_check(tuple_output, dict_output)
 
         for model_class in self.all_model_classes:
             model = model_class(config)