Implement head_mask for Flax BERT and other models copied from BERT (#14620)

* Implement head_mask for Flax BERT and other models copied from BERT

* Remove `from jax._src.nn.functions import sigmoid`

Remove `from jax._src.nn.functions import sigmoid` unintentionally added by IDE

* Remove no more valid copy statement

* Apply patil-suraj's suggestions from code review

* Apply suggestions from the code review

* Update Flax template

* Fix a typo

* Also update template for CausalLM modules
This commit is contained in:
Daniel Stancl
2021-12-17 17:06:59 +01:00
committed by GitHub
parent 95119ad7b0
commit ff066119ca
9 changed files with 484 additions and 48 deletions

View File

@@ -118,6 +118,8 @@ class FlaxBertModelTester(unittest.TestCase):
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
test_head_masking = True
all_model_classes = (
(
FlaxBertModel,

View File

@@ -111,6 +111,7 @@ class FlaxModelTesterMixin:
all_model_classes = ()
test_mismatched_shapes = True
is_encoder_decoder = False
test_head_masking = False
def _prepare_for_class(self, inputs_dict, model_class):
inputs_dict = copy.deepcopy(inputs_dict)
@@ -777,6 +778,53 @@ class FlaxModelTesterMixin:
for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
def test_headmasking(self):
if not self.test_head_masking:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
def _prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
if i == 0:
return np.concatenate([np.zeros(1, dtype=jnp.int32), np.ones(attention_heads - 1, dtype=jnp.int32)])
if i == num_hidden_layers - 1:
return np.concatenate([np.zeros(attention_heads - 1, dtype=jnp.int32), np.ones(1, dtype=jnp.int32)])
return np.ones(attention_heads, dtype=jnp.int32)
for model_class in self.all_model_classes:
model = model_class(config)
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
# Prepare head mask
inputs["head_mask"] = np.stack(
[
_prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
for i in range(config.num_hidden_layers)
]
)
outputs = model(**inputs)
def _check_attentions_validity(attentions):
# Remove NaN
for t in attentions:
# Check we don't have more than 25% nans (arbitrary)
self.assertLess(np.isnan(t).sum(), t.size / 4)
attentions = [np.where(np.isnan(t), 0.0, t) for t in attentions]
self.assertAlmostEqual(attentions[0][..., 0, :, :].sum(), 0.0)
self.assertNotEqual(attentions[0][..., -1, :, :].sum(), 0.0)
if len(attentions) > 2: # encoder-decodere models have only 2 layers in each modules
self.assertNotEqual(attentions[1][..., 0, :, :].sum(), 0.0)
self.assertAlmostEqual(attentions[-1][..., -2, :, :].sum(), 0.0)
self.assertNotEqual(attentions[-1][..., -1, :, :].sum(), 0.0)
if model.config.is_encoder_decoder:
raise NotImplementedError("The test has not been implemented for encoder-decoder models yet.")
else:
_check_attentions_validity(outputs.attentions)
@require_flax
@is_staging_test

View File

@@ -105,6 +105,8 @@ class FlaxElectraModelTester(unittest.TestCase):
@require_flax
class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase):
test_head_masking = True
all_model_classes = (
(
FlaxElectraModel,

View File

@@ -116,6 +116,8 @@ class FlaxRobertaModelTester(unittest.TestCase):
@require_flax
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
test_head_masking = True
all_model_classes = (
(
FlaxRobertaModel,