Fixed: the hard-coded max and min values are out of range in fp16, which causes NaN.

This commit is contained in:
Zili Wang
2019-09-11 15:41:53 +08:00
parent 7424b2848f
commit 8bdee1cb73
2 changed files with 26 additions and 15 deletions

View File

@@ -231,7 +231,7 @@ class PositionwiseFF(nn.Module):
class MultiHeadAttn(nn.Module): class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False): pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
super(MultiHeadAttn, self).__init__() super(MultiHeadAttn, self).__init__()
@@ -451,11 +451,19 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
if attn_mask is not None and torch.sum(attn_mask).item(): if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = (attn_mask == 1) # Switch to bool attn_mask = (attn_mask == 1) # Switch to bool
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_score = attn_score.float().masked_fill( if next(self.parameters()).dtype == torch.float16:
attn_mask[None,:,:,None], -1e30).type_as(attn_score) attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -65000).type_as(attn_score)
else:
attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -1e30).type_as(attn_score)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
attn_score = attn_score.float().masked_fill( if next(self.parameters()).dtype == torch.float16:
attn_mask[:,:,:,None], -1e30).type_as(attn_score) attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -65000).type_as(attn_score)
else:
attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -1e30).type_as(attn_score)
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1) attn_prob = F.softmax(attn_score, dim=1)
@@ -587,7 +595,7 @@ class DecoderLayer(nn.Module):
super(DecoderLayer, self).__init__() super(DecoderLayer, self).__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None): def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
@@ -607,7 +615,7 @@ class RelLearnableDecoderLayer(nn.Module):
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs) **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None): def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
@@ -628,7 +636,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs) d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None): def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
@@ -645,7 +653,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
class AdaptiveEmbedding(nn.Module): class AdaptiveEmbedding(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False): sample_softmax=False):
super(AdaptiveEmbedding, self).__init__() super(AdaptiveEmbedding, self).__init__()
@@ -683,7 +691,7 @@ class AdaptiveEmbedding(nn.Module):
else: else:
param = next(self.parameters()) param = next(self.parameters())
inp_flat = inp.view(-1) inp_flat = inp.view(-1)
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
dtype=param.dtype, device=param.device) dtype=param.dtype, device=param.device)
for i in range(len(self.cutoffs)): for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
@@ -852,7 +860,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self.n_head = config.n_head self.n_head = config.n_head
self.d_head = config.d_head self.d_head = config.d_head
self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs, self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
div_val=config.div_val) div_val=config.div_val)
self.drop = nn.Dropout(config.dropout) self.drop = nn.Dropout(config.dropout)
@@ -1011,7 +1019,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
hids = [] hids = []
attentions = [] attentions = []
if self.attn_type == 0: # default if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype) dtype=word_emb.dtype)
if self.clamp_len > 0: if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len) pos_seq.clamp_(max=self.clamp_len)
@@ -1165,7 +1173,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
self.sampler = LogUniformSampler(config.n_token, config.sample_softmax) self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
# use adaptive softmax (including standard softmax) # use adaptive softmax (including standard softmax)
else: else:
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
config.cutoffs, div_val=config.div_val) config.cutoffs, div_val=config.div_val)
self.init_weights() self.init_weights()
self.tie_weights() self.tie_weights()

View File

@@ -140,7 +140,7 @@ class PreTrainedModel(nn.Module):
Arguments: Arguments:
new_num_tokens: (`optional`) int: new_num_tokens: (`optional`) int:
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model. If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
Return: ``torch.nn.Embeddings`` Return: ``torch.nn.Embeddings``
@@ -434,7 +434,10 @@ class PoolerStartLogits(nn.Module):
x = self.dense(hidden_states).squeeze(-1) x = self.dense(hidden_states).squeeze(-1)
if p_mask is not None: if p_mask is not None:
x = x * (1 - p_mask) - 1e30 * p_mask if next(self.parameters()).dtype == torch.float16:
x = x * (1 - p_mask) - 65500 * p_mask
else:
x = x * (1 - p_mask) - 1e30 * p_mask
return x return x