Add head_mask/decoder_head_mask for TF BART models (#9639)
* Add head_mask/decoder_head_mask for TF BART models * Add head_mask and decoder_head_mask input arguments for TF BART-based models as a TF counterpart to the PR #9569 * Add test_headmasking functionality to tests/test_modeling_tf_common.py * TODO: Add a test to verify that we can get a gradient back for importance score computation * Remove redundant #TODO note Remove redundant #TODO note from tests/test_modeling_tf_common.py * Fix assertions * Make style * Fix ...Model input args and adjust one new test * Add back head_mask and decoder_head_mask to BART-based ...Model after the last commit * Remove head_mask ande decoder_head_mask from input_dict in TF test_train_pipeline_custom_model as these two have different shape than other input args (Necessary for passing this test) * Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
@@ -338,6 +338,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFunnelModelTester(self)
|
||||
@@ -376,6 +377,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFunnelModelTester(self, base=True)
|
||||
|
||||
Reference in New Issue
Block a user