Add head_mask/decoder_head_mask for TF BART models (#9639)
* Add head_mask/decoder_head_mask for TF BART models * Add head_mask and decoder_head_mask input arguments for TF BART-based models as a TF counterpart to the PR #9569 * Add test_headmasking functionality to tests/test_modeling_tf_common.py * TODO: Add a test to verify that we can get a gradient back for importance score computation * Remove redundant #TODO note Remove redundant #TODO note from tests/test_modeling_tf_common.py * Fix assertions * Make style * Fix ...Model input args and adjust one new test * Add back head_mask and decoder_head_mask to BART-based ...Model after the last commit * Remove head_mask ande decoder_head_mask from input_dict in TF test_train_pipeline_custom_model as these two have different shape than other input args (Necessary for passing this test) * Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
@@ -293,6 +293,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (
|
||||
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFXLMModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user