Add head_mask/decoder_head_mask for TF BART models (#9639)
* Add head_mask/decoder_head_mask for TF BART models * Add head_mask and decoder_head_mask input arguments for TF BART-based models as a TF counterpart to the PR #9569 * Add test_headmasking functionality to tests/test_modeling_tf_common.py * TODO: Add a test to verify that we can get a gradient back for importance score computation * Remove redundant #TODO note Remove redundant #TODO note from tests/test_modeling_tf_common.py * Fix assertions * Make style * Fix ...Model input args and adjust one new test * Add back head_mask and decoder_head_mask to BART-based ...Model after the last commit * Remove head_mask ande decoder_head_mask from input_dict in TF test_train_pipeline_custom_model as these two have different shape than other input args (Necessary for passing this test) * Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
@@ -107,10 +107,11 @@ class TFBlenderbotModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
@@ -143,6 +144,8 @@ def prepare_blenderbot_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@@ -154,11 +157,17 @@ def prepare_blenderbot_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -168,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user