Add head_mask/decoder_head_mask for TF BART models (#9639)
* Add head_mask/decoder_head_mask for TF BART models * Add head_mask and decoder_head_mask input arguments for TF BART-based models as a TF counterpart to the PR #9569 * Add test_headmasking functionality to tests/test_modeling_tf_common.py * TODO: Add a test to verify that we can get a gradient back for importance score computation * Remove redundant #TODO note Remove redundant #TODO note from tests/test_modeling_tf_common.py * Fix assertions * Make style * Fix ...Model input args and adjust one new test * Add back head_mask and decoder_head_mask to BART-based ...Model after the last commit * Remove head_mask ande decoder_head_mask from input_dict in TF test_train_pipeline_custom_model as these two have different shape than other input args (Necessary for passing this test) * Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
@@ -248,6 +248,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = True
|
||||
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFT5ModelTester(self)
|
||||
@@ -417,6 +418,7 @@ class TFT5EncoderOnlyModelTester:
|
||||
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = False
|
||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFT5EncoderOnlyModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user