replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm (#9386)
This commit is contained in:
@@ -22,7 +22,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@@ -109,16 +109,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm
|
||||
|
||||
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
||||
except ImportError:
|
||||
pass
|
||||
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
||||
|
||||
|
||||
class BartLearnedPositionalEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
|
||||
@@ -321,13 +311,13 @@ class BartEncoderLayer(nn.Module):
|
||||
dropout=config.attention_dropout,
|
||||
)
|
||||
self.normalize_before = config.normalize_before
|
||||
self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = BartLayerNorm(self.embed_dim)
|
||||
self.final_layer_norm = LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
|
||||
"""
|
||||
@@ -380,17 +370,17 @@ class BartDecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.normalize_before = config.normalize_before
|
||||
|
||||
self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = BartAttention(
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
)
|
||||
self.encoder_attn_layer_norm = BartLayerNorm(self.embed_dim)
|
||||
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
||||
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = BartLayerNorm(self.embed_dim)
|
||||
self.final_layer_norm = LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -672,9 +662,9 @@ class BartEncoder(BartPretrainedModel):
|
||||
config.extra_pos_embeddings,
|
||||
)
|
||||
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = BartLayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
|
||||
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
|
||||
# mbart has one extra layer_norm
|
||||
self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@@ -812,8 +802,8 @@ class BartDecoder(BartPretrainedModel):
|
||||
config.extra_pos_embeddings,
|
||||
)
|
||||
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layernorm_embedding = BartLayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
|
||||
self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
|
||||
self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
|
||||
|
||||
self.init_weights()
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@@ -264,16 +264,6 @@ FSMT_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
have_fused_layer_norm = False
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm
|
||||
|
||||
have_fused_layer_norm = True
|
||||
except ImportError:
|
||||
pass
|
||||
LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm
|
||||
|
||||
|
||||
def invert_mask(attention_mask):
|
||||
"""Turns 1->0, 0->1, False->True, True-> False"""
|
||||
assert attention_mask.dim() == 2
|
||||
|
||||
@@ -23,6 +23,7 @@ from typing import Dict, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@@ -510,16 +511,6 @@ class ProphetNetDecoderLMOutput(ModelOutput):
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm
|
||||
|
||||
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
||||
except ImportError:
|
||||
pass
|
||||
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
||||
|
||||
|
||||
class ProphetNetPreTrainedModel(PreTrainedModel):
|
||||
config_class = ProphetNetConfig
|
||||
base_model_prefix = "prophetnet"
|
||||
@@ -1044,11 +1035,11 @@ class ProphetNetEncoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
# 1st residual block
|
||||
self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
|
||||
self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
|
||||
self.self_attn_layer_norm = LayerNorm(config.hidden_size)
|
||||
|
||||
# 2nd residual block
|
||||
self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
|
||||
self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
|
||||
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
# 1st residual block
|
||||
@@ -1073,16 +1064,16 @@ class ProphetNetDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
# 1st residual block
|
||||
self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
|
||||
self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
|
||||
self.self_attn_layer_norm = LayerNorm(config.hidden_size)
|
||||
|
||||
# 2nd residual block
|
||||
if config.add_cross_attention:
|
||||
self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
|
||||
self.cross_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
|
||||
self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
|
||||
|
||||
# 3rd residual block
|
||||
self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
|
||||
self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
|
||||
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1154,7 +1145,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
|
||||
else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
)
|
||||
self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
|
||||
self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
|
||||
self.embeddings_layer_norm = LayerNorm(config.hidden_size)
|
||||
|
||||
self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
|
||||
|
||||
@@ -1274,7 +1265,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
|
||||
|
||||
self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
|
||||
self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
|
||||
self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
|
||||
self.embeddings_layer_norm = LayerNorm(config.hidden_size)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user