From 3b0a14b7614d5a1741dbd5288f6ab1e38d4a3103 Mon Sep 17 00:00:00 2001 From: Deyu Fu Date: Wed, 12 Dec 2018 15:05:45 -0800 Subject: [PATCH] add fallback path for apex used in modeling.py --- pytorch_pretrained_bert/modeling.py | 49 +++++++++++++---------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 7b2b05e2e9..72c7d47f86 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -31,10 +31,6 @@ import shutil import torch from torch import nn from torch.nn import CrossEntropyLoss -try: - from apex.normalization.fused_layer_norm import FusedLayerNorm -except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.") from .file_utils import cached_path @@ -157,22 +153,24 @@ class BertConfig(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" +try: + from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm +except ImportError: + print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") + class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps -class BertLayerNorm(nn.Module): - def __init__(self, config, variance_epsilon=1e-12): - """Construct a layernorm module in the TF style (epsilon inside the square root). - """ - super(BertLayerNorm, self).__init__() - self.gamma = nn.Parameter(torch.ones(config.hidden_size)) - self.beta = nn.Parameter(torch.zeros(config.hidden_size)) - self.variance_epsilon = variance_epsilon - - def forward(self, x): - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.gamma * x + self.beta - + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. @@ -185,7 +183,7 @@ class BertEmbeddings(nn.Module): # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file - self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None): @@ -260,7 +258,7 @@ class BertSelfOutput(nn.Module): def __init__(self, config): super(BertSelfOutput, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): @@ -299,7 +297,7 @@ class BertOutput(nn.Module): def __init__(self, config): super(BertOutput, self).__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): @@ -361,7 +359,7 @@ class BertPredictionHeadTransform(nn.Module): self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.transform_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act - self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -443,12 +441,9 @@ class PreTrainedBertModel(nn.Module): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, FusedLayerNorm): + elif isinstance(module, BertLayerNorm): module.bias.data.normal_(mean=0.0, std=self.config.initializer_range) module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, BertLayerNorm): - module.beta.data.normal_(mean=0.0, std=self.config.initializer_range) - module.gamma.data.normal_(mean=0.0, std=self.config.initializer_range) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_()