Fixed: the hard-coded maximum and minimum mask values are out of range in fp16, which causes NaN.
This commit is contained in:
@@ -231,7 +231,7 @@ class PositionwiseFF(nn.Module):
|
||||
|
||||
|
||||
class MultiHeadAttn(nn.Module):
|
||||
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
|
||||
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
|
||||
pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
|
||||
super(MultiHeadAttn, self).__init__()
|
||||
|
||||
@@ -451,11 +451,19 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
|
||||
if attn_mask is not None and torch.sum(attn_mask).item():
|
||||
attn_mask = (attn_mask == 1) # Switch to bool
|
||||
if attn_mask.dim() == 2:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[None,:,:,None], -1e30).type_as(attn_score)
|
||||
if next(self.parameters()).dtype == torch.float16:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[None,:,:,None], -65000).type_as(attn_score)
|
||||
else:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[None,:,:,None], -1e30).type_as(attn_score)
|
||||
elif attn_mask.dim() == 3:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[:,:,:,None], -1e30).type_as(attn_score)
|
||||
if next(self.parameters()).dtype == torch.float16:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[:,:,:,None], -65000).type_as(attn_score)
|
||||
else:
|
||||
attn_score = attn_score.float().masked_fill(
|
||||
attn_mask[:,:,:,None], -1e30).type_as(attn_score)
|
||||
|
||||
# [qlen x klen x bsz x n_head]
|
||||
attn_prob = F.softmax(attn_score, dim=1)
|
||||
@@ -587,7 +595,7 @@ class DecoderLayer(nn.Module):
|
||||
super(DecoderLayer, self).__init__()
|
||||
|
||||
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
|
||||
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
|
||||
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
|
||||
pre_lnorm=kwargs.get('pre_lnorm'))
|
||||
|
||||
def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
|
||||
@@ -607,7 +615,7 @@ class RelLearnableDecoderLayer(nn.Module):
|
||||
|
||||
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
|
||||
**kwargs)
|
||||
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
|
||||
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
|
||||
pre_lnorm=kwargs.get('pre_lnorm'))
|
||||
|
||||
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
|
||||
@@ -628,7 +636,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
|
||||
|
||||
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
|
||||
d_head, dropout, **kwargs)
|
||||
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
|
||||
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
|
||||
pre_lnorm=kwargs.get('pre_lnorm'))
|
||||
|
||||
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
|
||||
@@ -645,7 +653,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
|
||||
|
||||
|
||||
class AdaptiveEmbedding(nn.Module):
|
||||
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
|
||||
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
|
||||
sample_softmax=False):
|
||||
super(AdaptiveEmbedding, self).__init__()
|
||||
|
||||
@@ -683,7 +691,7 @@ class AdaptiveEmbedding(nn.Module):
|
||||
else:
|
||||
param = next(self.parameters())
|
||||
inp_flat = inp.view(-1)
|
||||
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
|
||||
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
|
||||
dtype=param.dtype, device=param.device)
|
||||
for i in range(len(self.cutoffs)):
|
||||
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
|
||||
@@ -852,7 +860,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
self.n_head = config.n_head
|
||||
self.d_head = config.d_head
|
||||
|
||||
self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
|
||||
self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
|
||||
div_val=config.div_val)
|
||||
|
||||
self.drop = nn.Dropout(config.dropout)
|
||||
@@ -1011,7 +1019,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
hids = []
|
||||
attentions = []
|
||||
if self.attn_type == 0: # default
|
||||
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
|
||||
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
|
||||
dtype=word_emb.dtype)
|
||||
if self.clamp_len > 0:
|
||||
pos_seq.clamp_(max=self.clamp_len)
|
||||
@@ -1165,7 +1173,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
|
||||
# use adaptive softmax (including standard softmax)
|
||||
else:
|
||||
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
|
||||
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
|
||||
config.cutoffs, div_val=config.div_val)
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
@@ -140,7 +140,7 @@ class PreTrainedModel(nn.Module):
|
||||
Arguments:
|
||||
|
||||
new_num_tokens: (`optional`) int:
|
||||
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
|
||||
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
|
||||
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
|
||||
|
||||
Return: ``torch.nn.Embeddings``
|
||||
@@ -434,7 +434,10 @@ class PoolerStartLogits(nn.Module):
|
||||
x = self.dense(hidden_states).squeeze(-1)
|
||||
|
||||
if p_mask is not None:
|
||||
x = x * (1 - p_mask) - 1e30 * p_mask
|
||||
if next(self.parameters()).dtype == torch.float16:
|
||||
x = x * (1 - p_mask) - 65500 * p_mask
|
||||
else:
|
||||
x = x * (1 - p_mask) - 1e30 * p_mask
|
||||
|
||||
return x
|
||||
|
||||
|
||||
Reference in New Issue
Block a user