support 3D attention mask in bert (#32105)

* support 3D/4D attention mask in bert

* test cases

* update doc

* fix doc
This commit is contained in:
Shiyu
2024-09-06 20:20:48 +08:00
committed by GitHub
parent 2b18354106
commit 342e800086
2 changed files with 42 additions and 4 deletions

View File

@@ -908,7 +908,7 @@ BERT_INPUTS_DOCSTRING = r"""
[`PreTrainedTokenizer.__call__`] for details. [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids) [What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): attention_mask (`torch.FloatTensor` of shape `({0})` or `(batch_size, sequence_length, target_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**, - 1 for tokens that are **not masked**,
@@ -1023,7 +1023,7 @@ class BertModel(BertPreTrainedModel):
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder. the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
@@ -1093,7 +1093,7 @@ class BertModel(BertPreTrainedModel):
) )
# Expand the attention mask # Expand the attention mask
if use_sdpa_attention_masks: if use_sdpa_attention_masks and attention_mask.dim() == 2:
# Expand the attention mask for SDPA. # Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
if self.config.is_decoder: if self.config.is_decoder:
@@ -1120,7 +1120,7 @@ class BertModel(BertPreTrainedModel):
if encoder_attention_mask is None: if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
if use_sdpa_attention_masks: if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
# Expand the attention mask for SDPA. # Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(

View File

@@ -498,6 +498,14 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
config_and_inputs[0].position_embedding_type = type config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_3d_mask_shapes(self):
    """Run the base model check with a 3D attention mask in place of the 2D one."""
    prepared = [*self.model_tester.prepare_config_and_inputs()]
    # Slot 3 of the prepared inputs is the input mask; swap it for a random
    # mask of shape (batch, seq, seq) built from the 2D mask's dimensions.
    mask_slot = 3
    n_batch, n_tokens = prepared[mask_slot].shape
    prepared[mask_slot] = random_attention_mask([n_batch, n_tokens, n_tokens])
    self.model_tester.create_and_check_model(*prepared)
def test_model_as_decoder(self): def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -535,6 +543,36 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
encoder_attention_mask, encoder_attention_mask,
) )
def test_model_as_decoder_with_3d_input_mask(self):
    """Run the decoder check with 3D self-attention and encoder attention masks."""

    def as_square_mask(mask_2d):
        # Build a random mask of shape (rows, cols, cols) from a 2D mask's shape.
        rows, cols = mask_2d.shape
        return random_attention_mask([rows, cols, cols])

    (
        config,
        input_ids,
        token_type_ids,
        flat_input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        flat_encoder_mask,
    ) = self.model_tester.prepare_config_and_inputs_for_decoder()

    # Generate the input mask first, then the encoder mask, so the random
    # generator is consumed in the same order as before.
    input_mask_3d = as_square_mask(flat_input_mask)
    encoder_mask_3d = as_square_mask(flat_encoder_mask)

    self.model_tester.create_and_check_model_as_decoder(
        config,
        input_ids,
        token_type_ids,
        input_mask_3d,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_mask_3d,
    )
def test_for_causal_lm(self): def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)