support 3D attention mask in bert (#32105)
* support 3D/4D attention mask in bert
* test cases
* update doc
* fix doc
This commit is contained in:
@@ -498,6 +498,14 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_3d_mask_shapes(self):
    """Run the base model check with a 3D attention mask.

    The fourth prepared input (index 3) is treated as the attention mask:
    its two dimensions are read as (batch, sequence) and it is replaced by
    a mask requested with shape [batch, sequence, sequence] before all
    inputs are forwarded to ``create_and_check_model``.
    """
    prepared = tuple(self.model_tester.prepare_config_and_inputs())

    # Swap the 2D mask for a 3D one; every other input is passed unchanged.
    n_batch, n_tokens = prepared[3].shape
    mask_3d = random_attention_mask([n_batch, n_tokens, n_tokens])

    self.model_tester.create_and_check_model(*prepared[:3], mask_3d, *prepared[4:])
|
||||
|
||||
def test_model_as_decoder(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
||||
@@ -535,6 +543,36 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def test_model_as_decoder_with_3d_input_mask(self):
    """Run the decoder model check with 3D input and encoder attention masks.

    Both masks returned by ``prepare_config_and_inputs_for_decoder`` are
    replaced by masks requested with shape [batch, sequence, sequence],
    derived from each original mask's own two dimensions, before calling
    ``create_and_check_model_as_decoder``.
    """

    def as_3d_mask(mask_2d):
        # Read (batch, sequence) from the existing mask and request a
        # square per-example mask of the same sequence length.
        n_batch, n_tokens = mask_2d.shape
        return random_attention_mask([n_batch, n_tokens, n_tokens])

    (
        config,
        input_ids,
        token_type_ids,
        self_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        cross_mask,
    ) = self.model_tester.prepare_config_and_inputs_for_decoder()

    # Input mask first, then encoder mask, so the two random masks are
    # drawn in the same order as before.
    self_mask = as_3d_mask(self_mask)
    cross_mask = as_3d_mask(cross_mask)

    self.model_tester.create_and_check_model_as_decoder(
        config,
        input_ids,
        token_type_ids,
        self_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        cross_mask,
    )
|
||||
|
||||
def test_for_causal_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user