model fixes + ipnb fixes
This commit is contained in:
@@ -377,12 +377,17 @@ class BertModel(nn.Module):
|
||||
self.encoder = BERTEncoder(config)
|
||||
self.pooler = BERTPooler(config)
|
||||
|
||||
def forward(self, input_ids, token_type_ids, attention_mask):
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None):
|
||||
# We create 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, from_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
|
||||
# It's more simple than the triangular masking of causal attention, just need to
|
||||
# prepare the broadcast here
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user