Add head_mask/decoder_head_mask for TF BART models (#9639)
* Add head_mask/decoder_head_mask for TF BART models * Add head_mask and decoder_head_mask input arguments for TF BART-based models as a TF counterpart to the PR #9569 * Add test_headmasking functionality to tests/test_modeling_tf_common.py * TODO: Add a test to verify that we can get a gradient back for importance score computation * Remove redundant #TODO note Remove redundant #TODO note from tests/test_modeling_tf_common.py * Fix assertions * Make style * Fix ...Model input args and adjust one new test * Add back head_mask and decoder_head_mask to BART-based ...Model after the last commit * Remove head_mask ande decoder_head_mask from input_dict in TF test_train_pipeline_custom_model as these two have different shape than other input args (Necessary for passing this test) * Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
@@ -440,6 +440,11 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_train_pipeline_custom_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# head_mask and decoder_head_mask has different shapes than other input args
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
if "decoder_head_mask" in inputs_dict:
|
||||
del inputs_dict["decoder_head_mask"]
|
||||
tf_main_layer_classes = set(
|
||||
module_member
|
||||
for model_class in self.all_model_classes
|
||||
@@ -620,6 +625,75 @@ class TFModelTesterMixin:
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
check_encoder_attentions_output(outputs)
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
return
|
||||
|
||||
random.Random().seed(42)
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
random.Random().seed()
|
||||
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = True
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
|
||||
# Prepare head_mask
|
||||
def prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
|
||||
if i == 0:
|
||||
return tf.concat(
|
||||
(tf.zeros(1, dtype=tf.float32), tf.ones(attention_heads - 1, dtype=tf.float32)), 0
|
||||
)
|
||||
elif i == num_hidden_layers - 1:
|
||||
return tf.concat(
|
||||
(tf.zeros(attention_heads - 1, dtype=tf.float32), tf.ones(1, dtype=tf.float32)), 0
|
||||
)
|
||||
else:
|
||||
return tf.ones(attention_heads, dtype=tf.float32)
|
||||
|
||||
head_mask = tf.stack(
|
||||
[
|
||||
prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
|
||||
for i in range(config.num_hidden_layers)
|
||||
],
|
||||
0,
|
||||
)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
|
||||
inputs["head_mask"] = head_mask
|
||||
if model.config.is_encoder_decoder:
|
||||
signature = inspect.signature(model.call)
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
|
||||
inputs["decoder_head_mask"] = head_mask
|
||||
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
|
||||
def check_attentions_validity(attentions):
|
||||
# Remove Nan
|
||||
for t in attentions:
|
||||
self.assertLess(
|
||||
(tf.math.reduce_sum(tf.cast(tf.math.is_nan(t), tf.float32))).numpy(), (tf.size(t) / 4).numpy()
|
||||
) # Check we don't have more than 25% nans (arbitrary)
|
||||
|
||||
attentions = [
|
||||
tf.where(tf.math.is_nan(t), 0.0, t) for t in attentions
|
||||
] # remove them (the test is less complete)
|
||||
|
||||
self.assertAlmostEqual(tf.math.reduce_sum(attentions[0][..., 0, :, :]).numpy(), 0.0)
|
||||
self.assertNotEqual(tf.math.reduce_sum(attentions[0][..., -1, :, :]).numpy(), 0.0)
|
||||
if len(attentions) > 2: # encoder-decodere models have only 2 layers in each modules
|
||||
self.assertNotEqual(tf.math.reduce_sum(attentions[1][..., 0, :, :]).numpy(), 0.0)
|
||||
self.assertAlmostEqual(tf.math.reduce_sum(attentions[-1][..., -2, :, :]).numpy(), 0.0)
|
||||
self.assertNotEqual(tf.math.reduce_sum(attentions[-1][..., -1, :, :]).numpy(), 0.0)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
check_attentions_validity(outputs.encoder_attentions)
|
||||
check_attentions_validity(outputs.decoder_attentions)
|
||||
else:
|
||||
check_attentions_validity(outputs.attentions)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user