fixed: hard coding for max and min number will out of range in fp16, which will cause nan.
This commit is contained in:
@@ -451,9 +451,17 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
|
||||
if attn_mask is not None and torch.sum(attn_mask).item():
|
||||
attn_mask = (attn_mask == 1) # Switch to bool
|
||||
if attn_mask.dim() == 2:
|
||||
if next(self.parameters()).dtype == torch.float16:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[None,:,:,None], -65000).type_as(attn_score)
|
||||
else:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[None,:,:,None], -1e30).type_as(attn_score)
|
||||
elif attn_mask.dim() == 3:
|
||||
if next(self.parameters()).dtype == torch.float16:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[:,:,:,None], -65000).type_as(attn_score)
|
||||
else:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[:,:,:,None], -1e30).type_as(attn_score)
|
||||
|
||||
|
||||
@@ -434,6 +434,9 @@ class PoolerStartLogits(nn.Module):
|
||||
x = self.dense(hidden_states).squeeze(-1)
|
||||
|
||||
if p_mask is not None:
|
||||
if next(self.parameters()).dtype == torch.float16:
|
||||
x = x * (1 - p_mask) - 65500 * p_mask
|
||||
else:
|
||||
x = x * (1 - p_mask) - 1e30 * p_mask
|
||||
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user