Fixed: the hard-coded max and min masking values are out of range in fp16, which causes NaN.

This commit is contained in:
Zili Wang
2019-09-11 15:41:53 +08:00
parent 7424b2848f
commit 8bdee1cb73
2 changed files with 26 additions and 15 deletions

View File

@@ -451,11 +451,19 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
if attn_mask is not None and torch.sum(attn_mask).item(): if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = (attn_mask == 1) # Switch to bool attn_mask = (attn_mask == 1) # Switch to bool
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_score = attn_score.float().masked_fill( if next(self.parameters()).dtype == torch.float16:
attn_mask[None,:,:,None], -1e30).type_as(attn_score) attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -65000).type_as(attn_score)
else:
attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -1e30).type_as(attn_score)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
attn_score = attn_score.float().masked_fill( if next(self.parameters()).dtype == torch.float16:
attn_mask[:,:,:,None], -1e30).type_as(attn_score) attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -65000).type_as(attn_score)
else:
attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -1e30).type_as(attn_score)
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1) attn_prob = F.softmax(attn_score, dim=1)

View File

@@ -434,7 +434,10 @@ class PoolerStartLogits(nn.Module):
x = self.dense(hidden_states).squeeze(-1) x = self.dense(hidden_states).squeeze(-1)
if p_mask is not None: if p_mask is not None:
x = x * (1 - p_mask) - 1e30 * p_mask if next(self.parameters()).dtype == torch.float16:
x = x * (1 - p_mask) - 65500 * p_mask
else:
x = x * (1 - p_mask) - 1e30 * p_mask
return x return x