From 47ca0eaaac5b993da1a22e190f8c938aa647d73b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 4 Jan 2021 10:00:08 -0800 Subject: [PATCH] replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm (#9386) --- src/transformers/models/bart/modeling_bart.py | 30 +++++++------------ src/transformers/models/fsmt/modeling_fsmt.py | 12 +------- .../models/prophetnet/modeling_prophetnet.py | 25 +++++----------- 3 files changed, 19 insertions(+), 48 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7f4af885d5..654de6a1c4 100644 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -22,7 +22,7 @@ import numpy as np import torch import torch.nn.functional as F from torch import nn -from torch.nn import CrossEntropyLoss +from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN from ...file_utils import ( @@ -109,16 +109,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) -def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True): - try: - from apex.normalization import FusedLayerNorm - - return FusedLayerNorm(normalized_shape, eps, elementwise_affine) - except ImportError: - pass - return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) - - class BartLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting @@ -321,13 +311,13 @@ class BartEncoderLayer(nn.Module): dropout=config.attention_dropout, ) self.normalize_before = config.normalize_before - self.self_attn_layer_norm = BartLayerNorm(self.embed_dim) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = BartLayerNorm(self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): """ @@ -380,17 +370,17 @@ class BartDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.normalize_before = config.normalize_before - self.self_attn_layer_norm = BartLayerNorm(self.embed_dim) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.encoder_attn = BartAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, ) - self.encoder_attn_layer_norm = BartLayerNorm(self.embed_dim) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) - self.final_layer_norm = BartLayerNorm(self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) def forward( self, @@ -672,9 +662,9 @@ class BartEncoder(BartPretrainedModel): config.extra_pos_embeddings, ) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self.layernorm_embedding = BartLayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() + self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() # mbart has one extra layer_norm - self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None + self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None self.init_weights() @@ -812,8 +802,8 @@ class BartDecoder(BartPretrainedModel): config.extra_pos_embeddings, ) self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self.layernorm_embedding = BartLayerNorm(config.d_model) if config.normalize_embedding else nn.Identity() - self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None + self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity() + self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None self.init_weights() diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 0cd07ed6e6..58d7ce9000 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -34,7 +34,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn.functional as F from torch import Tensor, nn -from torch.nn import CrossEntropyLoss +from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN from ...file_utils import ( @@ -264,16 +264,6 @@ FSMT_INPUTS_DOCSTRING = r""" """ -have_fused_layer_norm = False -try: - from apex.normalization import FusedLayerNorm - - have_fused_layer_norm = True -except ImportError: - pass -LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm - - def invert_mask(attention_mask): """Turns 1->0, 0->1, False->True, True-> False""" assert attention_mask.dim() == 2 diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 17db02b5b2..ffc26a9721 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -23,6 +23,7 @@ from typing import Dict, Optional, Tuple import torch import torch.nn.functional as F from torch import Tensor, nn +from torch.nn import LayerNorm from ...activations import ACT2FN from ...file_utils import ( @@ -510,16 +511,6 @@ class ProphetNetDecoderLMOutput(ModelOutput): cross_attentions: Optional[Tuple[torch.FloatTensor]] = None -def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): - try: - from apex.normalization import FusedLayerNorm - - return FusedLayerNorm(normalized_shape, eps, elementwise_affine) - except ImportError: - pass - return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) - - class ProphetNetPreTrainedModel(PreTrainedModel): config_class = ProphetNetConfig base_model_prefix = "prophetnet" @@ -1044,11 +1035,11 @@ class ProphetNetEncoderLayer(nn.Module): super().__init__() # 1st residual block self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads) - self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size) + self.self_attn_layer_norm = LayerNorm(config.hidden_size) # 2nd residual block self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim) - self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size) + self.feed_forward_layer_norm = LayerNorm(config.hidden_size) def forward(self, hidden_states, attention_mask): # 1st residual block @@ -1073,16 +1064,16 @@ class ProphetNetDecoderLayer(nn.Module): super().__init__() # 1st residual block self.self_attn = ProphetNetNgramProphetNetSelfAttention(config) - self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size) + self.self_attn_layer_norm = LayerNorm(config.hidden_size) # 2nd residual block if config.add_cross_attention: self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads) - self.cross_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size) + self.cross_attn_layer_norm = LayerNorm(config.hidden_size) # 3rd residual block self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim) - self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size) + self.feed_forward_layer_norm = LayerNorm(config.hidden_size) def forward( self, @@ -1154,7 +1145,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) ) self.position_embeddings = ProhpetNetPositionalEmbeddings(config) - self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size) + self.embeddings_layer_norm = LayerNorm(config.hidden_size) self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)]) @@ -1274,7 +1265,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None) self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)]) - self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size) + self.embeddings_layer_norm = LayerNorm(config.hidden_size) self.init_weights()