Implement head_mask for Flax BERT and other models copied from BERT (#14620)
* Implement head_mask for Flax BERT and other models copied from BERT * Remove `from jax._src.nn.functions import sigmoid` Remove `from jax._src.nn.functions import sigmoid` unintentionally added by IDE * Remove no more valid copy statement * Apply patil-suraj's suggestions from code review * Apply suggestions from the code review * Update Flax template * Fix a typo * Also update template for CausalLM modules
This commit is contained in:
@@ -111,6 +111,7 @@ class FlaxModelTesterMixin:
|
||||
all_model_classes = ()
|
||||
test_mismatched_shapes = True
|
||||
is_encoder_decoder = False
|
||||
test_head_masking = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
@@ -777,6 +778,53 @@ class FlaxModelTesterMixin:
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
return
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
def _prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
|
||||
if i == 0:
|
||||
return np.concatenate([np.zeros(1, dtype=jnp.int32), np.ones(attention_heads - 1, dtype=jnp.int32)])
|
||||
if i == num_hidden_layers - 1:
|
||||
return np.concatenate([np.zeros(attention_heads - 1, dtype=jnp.int32), np.ones(1, dtype=jnp.int32)])
|
||||
return np.ones(attention_heads, dtype=jnp.int32)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
|
||||
# Prepare head mask
|
||||
inputs["head_mask"] = np.stack(
|
||||
[
|
||||
_prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
outputs = model(**inputs)
|
||||
|
||||
def _check_attentions_validity(attentions):
|
||||
# Remove NaN
|
||||
for t in attentions:
|
||||
# Check we don't have more than 25% nans (arbitrary)
|
||||
self.assertLess(np.isnan(t).sum(), t.size / 4)
|
||||
attentions = [np.where(np.isnan(t), 0.0, t) for t in attentions]
|
||||
|
||||
self.assertAlmostEqual(attentions[0][..., 0, :, :].sum(), 0.0)
|
||||
self.assertNotEqual(attentions[0][..., -1, :, :].sum(), 0.0)
|
||||
if len(attentions) > 2: # encoder-decodere models have only 2 layers in each modules
|
||||
self.assertNotEqual(attentions[1][..., 0, :, :].sum(), 0.0)
|
||||
self.assertAlmostEqual(attentions[-1][..., -2, :, :].sum(), 0.0)
|
||||
self.assertNotEqual(attentions[-1][..., -1, :, :].sum(), 0.0)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
raise NotImplementedError("The test has not been implemented for encoder-decoder models yet.")
|
||||
else:
|
||||
_check_attentions_validity(outputs.attentions)
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user