Fix missing output_attentions in PT/Flax equivalence test (#16271)
* fix - set output_attentions to True
* Update tests/test_modeling_flax_common.py
* update for has_attentions
* overwrite check_outputs in FlaxBigBirdModelTest

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -190,3 +190,12 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
|
|
||||||
self.assertEqual(jitted_output.shape, output.shape)
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
|
# overwrite from common in order to skip the check on `attentions`
|
||||||
|
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
||||||
|
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
|
||||||
|
# an effort was made to return `attention_probs` (yet to be verified).
|
||||||
|
if type(names) == str and names.startswith("attentions"):
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
super().check_outputs(fx_outputs, pt_outputs, model_class, names)
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ class FlaxModelTesterMixin:
|
|||||||
test_mismatched_shapes = True
|
test_mismatched_shapes = True
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
has_attentions = True
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class):
|
def _prepare_for_class(self, inputs_dict, model_class):
|
||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
@@ -168,6 +169,7 @@ class FlaxModelTesterMixin:
|
|||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
# (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
|
||||||
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -204,8 +206,7 @@ class FlaxModelTesterMixin:
|
|||||||
pt_outputs[pt_nans] = 0
|
pt_outputs[pt_nans] = 0
|
||||||
fx_outputs[pt_nans] = 0
|
fx_outputs[pt_nans] = 0
|
||||||
|
|
||||||
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
|
self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
|
||||||
self.assertLessEqual(max_diff, 1e-5)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
|
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
|
||||||
@@ -222,6 +223,7 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
# Output all for aggressive testing
|
# Output all for aggressive testing
|
||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
|
config.output_attentions = self.has_attentions
|
||||||
|
|
||||||
# prepare inputs
|
# prepare inputs
|
||||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
@@ -274,7 +276,7 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
# Output all for aggressive testing
|
# Output all for aggressive testing
|
||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
# Pure convolutional models have no attention
|
config.output_attentions = self.has_attentions
|
||||||
|
|
||||||
# prepare inputs
|
# prepare inputs
|
||||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
@@ -314,6 +316,7 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
# send pytorch model to the correct device
|
# send pytorch model to the correct device
|
||||||
pt_model_loaded.to(torch_device)
|
pt_model_loaded.to(torch_device)
|
||||||
|
pt_model_loaded.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user