support 3D attention mask in bert (#32105)
* support 3D/4D attention mask in bert * test cases * update doc * fix doc
This commit is contained in:
@@ -908,7 +908,7 @@ BERT_INPUTS_DOCSTRING = r"""
|
|||||||
[`PreTrainedTokenizer.__call__`] for details.
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
[What are input IDs?](../glossary#input-ids)
|
||||||
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
attention_mask (`torch.FloatTensor` of shape `({0})`or `(batch_size, sequence_length, target_length)`, *optional*):
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
- 1 for tokens that are **not masked**,
|
||||||
@@ -1023,7 +1023,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
the model is configured as a decoder.
|
the model is configured as a decoder.
|
||||||
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
|
||||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
@@ -1093,7 +1093,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Expand the attention mask
|
# Expand the attention mask
|
||||||
if use_sdpa_attention_masks:
|
if use_sdpa_attention_masks and attention_mask.dim() == 2:
|
||||||
# Expand the attention mask for SDPA.
|
# Expand the attention mask for SDPA.
|
||||||
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
||||||
if self.config.is_decoder:
|
if self.config.is_decoder:
|
||||||
@@ -1120,7 +1120,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
if encoder_attention_mask is None:
|
if encoder_attention_mask is None:
|
||||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
|
||||||
if use_sdpa_attention_masks:
|
if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
|
||||||
# Expand the attention mask for SDPA.
|
# Expand the attention mask for SDPA.
|
||||||
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
||||||
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||||
|
|||||||
@@ -498,6 +498,14 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
config_and_inputs[0].position_embedding_type = type
|
config_and_inputs[0].position_embedding_type = type
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_model_3d_mask_shapes(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
# manipulate input_mask
|
||||||
|
config_and_inputs = list(config_and_inputs)
|
||||||
|
batch_size, seq_length = config_and_inputs[3].shape
|
||||||
|
config_and_inputs[3] = random_attention_mask([batch_size, seq_length, seq_length])
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
def test_model_as_decoder(self):
|
def test_model_as_decoder(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
||||||
@@ -535,6 +543,36 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_model_as_decoder_with_3d_input_mask(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
) = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
|
||||||
|
batch_size, seq_length = input_mask.shape
|
||||||
|
input_mask = random_attention_mask([batch_size, seq_length, seq_length])
|
||||||
|
batch_size, seq_length = encoder_attention_mask.shape
|
||||||
|
encoder_attention_mask = random_attention_mask([batch_size, seq_length, seq_length])
|
||||||
|
|
||||||
|
self.model_tester.create_and_check_model_as_decoder(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def test_for_causal_lm(self):
|
def test_for_causal_lm(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user