From 8bdee1cb733d3c735c64a1ab5e7b41cec02d8fa3 Mon Sep 17 00:00:00 2001 From: Zili Wang Date: Wed, 11 Sep 2019 15:41:53 +0800 Subject: [PATCH] fixed: hard coding for max and min number will out of range in fp16, which will cause nan. --- pytorch_transformers/modeling_transfo_xl.py | 34 +++++++++++++-------- pytorch_transformers/modeling_utils.py | 7 +++-- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 7ad9d10891..228f75fc82 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -231,7 +231,7 @@ class PositionwiseFF(nn.Module): class MultiHeadAttn(nn.Module): - def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, + def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False): super(MultiHeadAttn, self).__init__() @@ -451,11 +451,19 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): if attn_mask is not None and torch.sum(attn_mask).item(): attn_mask = (attn_mask == 1) # Switch to bool if attn_mask.dim() == 2: - attn_score = attn_score.float().masked_fill( - attn_mask[None,:,:,None], -1e30).type_as(attn_score) + if next(self.parameters()).dtype == torch.float16: + attn_score = attn_score.float().masked_fill( + attn_mask[None,:,:,None], -65000).type_as(attn_score) + else: + attn_score = attn_score.float().masked_fill( + attn_mask[None,:,:,None], -1e30).type_as(attn_score) elif attn_mask.dim() == 3: - attn_score = attn_score.float().masked_fill( - attn_mask[:,:,:,None], -1e30).type_as(attn_score) + if next(self.parameters()).dtype == torch.float16: + attn_score = attn_score.float().masked_fill( + attn_mask[:,:,:,None], -65000).type_as(attn_score) + else: + attn_score = attn_score.float().masked_fill( + attn_mask[:,:,:,None], -1e30).type_as(attn_score) # [qlen x klen x bsz x n_head] attn_prob = F.softmax(attn_score, dim=1) @@ -587,7 +595,7 @@ class DecoderLayer(nn.Module): super(DecoderLayer, self).__init__() self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) - self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, + self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')) def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None): @@ -607,7 +615,7 @@ class RelLearnableDecoderLayer(nn.Module): self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) - self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, + self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')) def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None): @@ -628,7 +636,7 @@ class RelPartialLearnableDecoderLayer(nn.Module): self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) - self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, + self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')) def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None): @@ -645,7 +653,7 @@ class RelPartialLearnableDecoderLayer(nn.Module): class AdaptiveEmbedding(nn.Module): - def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, + def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False): super(AdaptiveEmbedding, self).__init__() @@ -683,7 +691,7 @@ class AdaptiveEmbedding(nn.Module): else: param = next(self.parameters()) inp_flat = inp.view(-1) - emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], + emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device) for i in range(len(self.cutoffs)): l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] @@ -852,7 +860,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): self.n_head = config.n_head self.d_head = config.d_head - self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs, + self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val) self.drop = nn.Dropout(config.dropout) @@ -1011,7 +1019,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): hids = [] attentions = [] if self.attn_type == 0: # default - pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, + pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype) if self.clamp_len > 0: pos_seq.clamp_(max=self.clamp_len) @@ -1165,7 +1173,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): self.sampler = LogUniformSampler(config.n_token, config.sample_softmax) # use adaptive softmax (including standard softmax) else: - self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, + self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val) self.init_weights() self.tie_weights() diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 2fb4671674..25aeefe10f 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -140,7 +140,7 @@ class PreTrainedModel(nn.Module): Arguments: new_num_tokens: (`optional`) int: - New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. + New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model. Return: ``torch.nn.Embeddings`` @@ -434,7 +434,10 @@ class PoolerStartLogits(nn.Module): x = self.dense(hidden_states).squeeze(-1) if p_mask is not None: - x = x * (1 - p_mask) - 1e30 * p_mask + if next(self.parameters()).dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask return x