fix typo in comments (#6838)
This commit is contained in:
committed by
GitHub
parent
7351ef83c1
commit
6b24281229
@@ -803,8 +803,8 @@ class BertModel(BertPreTrainedModel):
|
|||||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||||
|
|
||||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
|||||||
Reference in New Issue
Block a user