model fixes + ipnb fixes

This commit is contained in:
thomwolf
2018-11-02 15:11:16 +01:00
parent 3ff2ec5eb3
commit c84315ec35
3 changed files with 867 additions and 54 deletions

View File

@@ -377,12 +377,17 @@ class BertModel(nn.Module):
self.encoder = BERTEncoder(config)
self.pooler = BERTPooler(config)
def forward(self, input_ids, token_type_ids, attention_mask):
def forward(self, input_ids, token_type_ids=None, attention_mask=None):
# We create 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, from_seq_length]
# So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
# It's more simple than the triangular masking of causal attention, just need to
# prepare the broadcast here
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (1.0 - attention_mask) * -10000.0