Add head_mask/decoder_head_mask for TF BART models (#9639)

* Add head_mask/decoder_head_mask for TF BART models

* Add head_mask and decoder_head_mask input arguments for TF BART-based
models as a TF counterpart to the PR #9569

* Add test_headmasking functionality to tests/test_modeling_tf_common.py

* TODO: Add a test to verify that we can get a gradient back for
importance score computation

* Remove redundant #TODO note

Remove redundant #TODO note from tests/test_modeling_tf_common.py

* Fix assertions

* Make style

* Fix ...Model input args and adjust one new test

* Add back head_mask and decoder_head_mask to BART-based ...Model
after the last commit

* Remove head_mask ande decoder_head_mask from input_dict
in TF test_train_pipeline_custom_model as these two have different
shape than other input args (Necessary for passing this test)

* Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
Daniel Stancl
2021-01-26 09:50:00 +01:00
committed by GitHub
parent cb73ab5a38
commit 1867d9a8d7
32 changed files with 849 additions and 36 deletions

View File

@@ -240,6 +240,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFAlbertModelTester(self)

View File

@@ -108,10 +108,11 @@ class TFBartModelTester:
input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -144,6 +145,8 @@ def prepare_bart_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -155,11 +158,17 @@ def prepare_bart_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": head_mask,
}
@@ -169,6 +178,7 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True
def setUp(self):
self.model_tester = TFBartModelTester(self)

View File

@@ -273,6 +273,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@@ -107,10 +107,11 @@ class TFBlenderbotModelTester:
input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -143,6 +144,8 @@ def prepare_blenderbot_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -154,11 +157,17 @@ def prepare_blenderbot_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@@ -168,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True
def setUp(self):
self.model_tester = TFBlenderbotModelTester(self)

View File

@@ -107,10 +107,11 @@ class TFBlenderbotSmallModelTester:
input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -143,6 +144,8 @@ def prepare_blenderbot_small_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -154,11 +157,17 @@ def prepare_blenderbot_small_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@@ -170,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True
def setUp(self):
self.model_tester = TFBlenderbotSmallModelTester(self)

View File

@@ -440,6 +440,11 @@ class TFModelTesterMixin:
def test_train_pipeline_custom_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# head_mask and decoder_head_mask has different shapes than other input args
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
if "decoder_head_mask" in inputs_dict:
del inputs_dict["decoder_head_mask"]
tf_main_layer_classes = set(
module_member
for model_class in self.all_model_classes
@@ -620,6 +625,75 @@ class TFModelTesterMixin:
self.assertEqual(model.config.output_hidden_states, True)
check_encoder_attentions_output(outputs)
def test_headmasking(self):
if not self.test_head_masking:
return
random.Random().seed(42)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
random.Random().seed()
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Prepare head_mask
def prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
if i == 0:
return tf.concat(
(tf.zeros(1, dtype=tf.float32), tf.ones(attention_heads - 1, dtype=tf.float32)), 0
)
elif i == num_hidden_layers - 1:
return tf.concat(
(tf.zeros(attention_heads - 1, dtype=tf.float32), tf.ones(1, dtype=tf.float32)), 0
)
else:
return tf.ones(attention_heads, dtype=tf.float32)
head_mask = tf.stack(
[
prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
for i in range(config.num_hidden_layers)
],
0,
)
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
inputs["head_mask"] = head_mask
if model.config.is_encoder_decoder:
signature = inspect.signature(model.call)
arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
inputs["decoder_head_mask"] = head_mask
outputs = model(**inputs, return_dict=True)
def check_attentions_validity(attentions):
# Remove Nan
for t in attentions:
self.assertLess(
(tf.math.reduce_sum(tf.cast(tf.math.is_nan(t), tf.float32))).numpy(), (tf.size(t) / 4).numpy()
) # Check we don't have more than 25% nans (arbitrary)
attentions = [
tf.where(tf.math.is_nan(t), 0.0, t) for t in attentions
] # remove them (the test is less complete)
self.assertAlmostEqual(tf.math.reduce_sum(attentions[0][..., 0, :, :]).numpy(), 0.0)
self.assertNotEqual(tf.math.reduce_sum(attentions[0][..., -1, :, :]).numpy(), 0.0)
if len(attentions) > 2: # encoder-decodere models have only 2 layers in each modules
self.assertNotEqual(tf.math.reduce_sum(attentions[1][..., 0, :, :]).numpy(), 0.0)
self.assertAlmostEqual(tf.math.reduce_sum(attentions[-1][..., -2, :, :]).numpy(), 0.0)
self.assertNotEqual(tf.math.reduce_sum(attentions[-1][..., -1, :, :]).numpy(), 0.0)
if model.config.is_encoder_decoder:
check_attentions_validity(outputs.encoder_attentions)
check_attentions_validity(outputs.decoder_attentions)
else:
check_attentions_validity(outputs.attentions)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -173,6 +173,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
test_head_masking = False
def setUp(self):
self.model_tester = TFCTRLModelTester(self)

View File

@@ -183,6 +183,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else None
)
test_head_masking = False
def setUp(self):
self.model_tester = TFDistilBertModelTester(self)

View File

@@ -205,6 +205,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFElectraModelTester(self)

View File

@@ -291,6 +291,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
def setUp(self):
self.model_tester = TFFlaubertModelTester(self)

View File

@@ -338,6 +338,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFFunnelModelTester(self)
@@ -376,6 +377,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFFunnelModelTester(self, base=True)

View File

@@ -332,6 +332,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
test_head_masking = False
def setUp(self):
self.model_tester = TFGPT2ModelTester(self)

View File

@@ -187,6 +187,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = TFLEDModelTester(self)

View File

@@ -297,6 +297,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFLongformerModelTester(self)

View File

@@ -361,6 +361,7 @@ class TFLxmertModelTester(object):
class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
test_head_masking = False
def setUp(self):
self.model_tester = TFLxmertModelTester(self)

View File

@@ -109,10 +109,11 @@ class TFMarianModelTester:
input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -145,6 +146,8 @@ def prepare_marian_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -156,11 +159,17 @@ def prepare_marian_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@@ -170,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True
def setUp(self):
self.model_tester = TFMarianModelTester(self)

View File

@@ -106,10 +106,11 @@ class TFMBartModelTester:
input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -147,6 +148,8 @@ def prepare_mbart_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -158,11 +161,17 @@ def prepare_mbart_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": head_mask,
}
@@ -172,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True
def setUp(self):
self.model_tester = TFMBartModelTester(self)

View File

@@ -55,6 +55,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
class TFMobileBertModelTester(object):
def __init__(

View File

@@ -198,6 +198,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFMPNetModelTester(self)

View File

@@ -202,6 +202,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
test_head_masking = False
def setUp(self):
self.model_tester = TFOpenAIGPTModelTester(self)

View File

@@ -107,10 +107,11 @@ class TFPegasusModelTester:
input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -143,6 +144,8 @@ def prepare_pegasus_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -154,11 +157,17 @@ def prepare_pegasus_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@@ -168,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True
def setUp(self):
self.model_tester = TFPegasusModelTester(self)

View File

@@ -185,6 +185,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFRobertaModelTester(self)

View File

@@ -248,6 +248,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
test_head_masking = False
def setUp(self):
self.model_tester = TFT5ModelTester(self)
@@ -417,6 +418,7 @@ class TFT5EncoderOnlyModelTester:
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
test_head_masking = False
def setUp(self):
self.model_tester = TFT5EncoderOnlyModelTester(self)

View File

@@ -163,6 +163,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = () if is_tf_available() else ()
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = TFTransfoXLModelTester(self)

View File

@@ -293,6 +293,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
def setUp(self):
self.model_tester = TFXLMModelTester(self)

View File

@@ -347,6 +347,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (
(TFXLNetLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
def setUp(self):
self.model_tester = TFXLNetModelTester(self)