Add head_mask/decoder_head_mask for TF BART models (#9639)
* Add head_mask/decoder_head_mask for TF BART models * Add head_mask and decoder_head_mask input arguments for TF BART-based models as a TF counterpart to the PR #9569 * Add test_headmasking functionality to tests/test_modeling_tf_common.py * TODO: Add a test to verify that we can get a gradient back for importance score computation * Remove redundant #TODO note Remove redundant #TODO note from tests/test_modeling_tf_common.py * Fix assertions * Make style * Fix ...Model input args and adjust one new test * Add back head_mask and decoder_head_mask to BART-based ...Model after the last commit * Remove head_mask ande decoder_head_mask from input_dict in TF test_train_pipeline_custom_model as these two have different shape than other input args (Necessary for passing this test) * Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
@@ -240,6 +240,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFAlbertModelTester(self)
|
||||
|
||||
@@ -108,10 +108,11 @@ class TFBartModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
@@ -144,6 +145,8 @@ def prepare_bart_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@@ -155,11 +158,17 @@ def prepare_bart_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -169,6 +178,7 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBartModelTester(self)
|
||||
|
||||
@@ -273,6 +273,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -107,10 +107,11 @@ class TFBlenderbotModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
@@ -143,6 +144,8 @@ def prepare_blenderbot_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@@ -154,11 +157,17 @@ def prepare_blenderbot_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -168,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotModelTester(self)
|
||||
|
||||
@@ -107,10 +107,11 @@ class TFBlenderbotSmallModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
@@ -143,6 +144,8 @@ def prepare_blenderbot_small_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@@ -154,11 +157,17 @@ def prepare_blenderbot_small_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -170,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotSmallModelTester(self)
|
||||
|
||||
@@ -440,6 +440,11 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_train_pipeline_custom_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# head_mask and decoder_head_mask has different shapes than other input args
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
if "decoder_head_mask" in inputs_dict:
|
||||
del inputs_dict["decoder_head_mask"]
|
||||
tf_main_layer_classes = set(
|
||||
module_member
|
||||
for model_class in self.all_model_classes
|
||||
@@ -620,6 +625,75 @@ class TFModelTesterMixin:
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
check_encoder_attentions_output(outputs)
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
return
|
||||
|
||||
random.Random().seed(42)
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
random.Random().seed()
|
||||
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = True
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
|
||||
# Prepare head_mask
|
||||
def prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
|
||||
if i == 0:
|
||||
return tf.concat(
|
||||
(tf.zeros(1, dtype=tf.float32), tf.ones(attention_heads - 1, dtype=tf.float32)), 0
|
||||
)
|
||||
elif i == num_hidden_layers - 1:
|
||||
return tf.concat(
|
||||
(tf.zeros(attention_heads - 1, dtype=tf.float32), tf.ones(1, dtype=tf.float32)), 0
|
||||
)
|
||||
else:
|
||||
return tf.ones(attention_heads, dtype=tf.float32)
|
||||
|
||||
head_mask = tf.stack(
|
||||
[
|
||||
prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
|
||||
for i in range(config.num_hidden_layers)
|
||||
],
|
||||
0,
|
||||
)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
|
||||
inputs["head_mask"] = head_mask
|
||||
if model.config.is_encoder_decoder:
|
||||
signature = inspect.signature(model.call)
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
|
||||
inputs["decoder_head_mask"] = head_mask
|
||||
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
|
||||
def check_attentions_validity(attentions):
|
||||
# Remove Nan
|
||||
for t in attentions:
|
||||
self.assertLess(
|
||||
(tf.math.reduce_sum(tf.cast(tf.math.is_nan(t), tf.float32))).numpy(), (tf.size(t) / 4).numpy()
|
||||
) # Check we don't have more than 25% nans (arbitrary)
|
||||
|
||||
attentions = [
|
||||
tf.where(tf.math.is_nan(t), 0.0, t) for t in attentions
|
||||
] # remove them (the test is less complete)
|
||||
|
||||
self.assertAlmostEqual(tf.math.reduce_sum(attentions[0][..., 0, :, :]).numpy(), 0.0)
|
||||
self.assertNotEqual(tf.math.reduce_sum(attentions[0][..., -1, :, :]).numpy(), 0.0)
|
||||
if len(attentions) > 2: # encoder-decodere models have only 2 layers in each modules
|
||||
self.assertNotEqual(tf.math.reduce_sum(attentions[1][..., 0, :, :]).numpy(), 0.0)
|
||||
self.assertAlmostEqual(tf.math.reduce_sum(attentions[-1][..., -2, :, :]).numpy(), 0.0)
|
||||
self.assertNotEqual(tf.math.reduce_sum(attentions[-1][..., -1, :, :]).numpy(), 0.0)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
check_attentions_validity(outputs.encoder_attentions)
|
||||
check_attentions_validity(outputs.decoder_attentions)
|
||||
else:
|
||||
check_attentions_validity(outputs.attentions)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -173,6 +173,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFCTRLModelTester(self)
|
||||
|
||||
@@ -183,6 +183,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else None
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFDistilBertModelTester(self)
|
||||
|
||||
@@ -205,6 +205,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFElectraModelTester(self)
|
||||
|
||||
@@ -291,6 +291,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (
|
||||
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFlaubertModelTester(self)
|
||||
|
||||
@@ -338,6 +338,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFunnelModelTester(self)
|
||||
@@ -376,6 +377,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFunnelModelTester(self, base=True)
|
||||
|
||||
@@ -332,6 +332,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFGPT2ModelTester(self)
|
||||
|
||||
@@ -187,6 +187,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLEDModelTester(self)
|
||||
|
||||
@@ -297,6 +297,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLongformerModelTester(self)
|
||||
|
||||
@@ -361,6 +361,7 @@ class TFLxmertModelTester(object):
|
||||
class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLxmertModelTester(self)
|
||||
|
||||
@@ -109,10 +109,11 @@ class TFMarianModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
@@ -145,6 +146,8 @@ def prepare_marian_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@@ -156,11 +159,17 @@ def prepare_marian_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -170,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMarianModelTester(self)
|
||||
|
||||
@@ -106,10 +106,11 @@ class TFMBartModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
@@ -147,6 +148,8 @@ def prepare_mbart_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@@ -158,11 +161,17 @@ def prepare_mbart_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -172,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMBartModelTester(self)
|
||||
|
||||
@@ -55,6 +55,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
class TFMobileBertModelTester(object):
|
||||
def __init__(
|
||||
|
||||
@@ -198,6 +198,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMPNetModelTester(self)
|
||||
|
||||
@@ -202,6 +202,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (
|
||||
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFOpenAIGPTModelTester(self)
|
||||
|
||||
@@ -107,10 +107,11 @@ class TFPegasusModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
@@ -143,6 +144,8 @@ def prepare_pegasus_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@@ -154,11 +157,17 @@ def prepare_pegasus_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -168,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFPegasusModelTester(self)
|
||||
|
||||
@@ -185,6 +185,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFRobertaModelTester(self)
|
||||
|
||||
@@ -248,6 +248,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = True
|
||||
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFT5ModelTester(self)
|
||||
@@ -417,6 +418,7 @@ class TFT5EncoderOnlyModelTester:
|
||||
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = False
|
||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFT5EncoderOnlyModelTester(self)
|
||||
|
||||
@@ -163,6 +163,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = () if is_tf_available() else ()
|
||||
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFTransfoXLModelTester(self)
|
||||
|
||||
@@ -293,6 +293,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (
|
||||
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFXLMModelTester(self)
|
||||
|
||||
@@ -347,6 +347,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (
|
||||
(TFXLNetLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFXLNetModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user