[XLNet] Use PyTorch's LayerNorm, as in BERT
See #1089. cc @thomwolf @lysandrejik, also @dhpollack
This commit is contained in:
@@ -337,20 +337,7 @@ try:
|
|||||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
|
from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
|
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
|
||||||
class XLNetLayerNorm(nn.Module):
|
from torch.nn import LayerNorm as XLNetLayerNorm
|
||||||
def __init__(self, d_model, eps=1e-12):
|
|
||||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
|
||||||
"""
|
|
||||||
super(XLNetLayerNorm, self).__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(d_model))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(d_model))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
u = x.mean(-1, keepdim=True)
|
|
||||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
|
||||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
|
||||||
return self.weight * x + self.bias
|
|
||||||
|
|
||||||
class XLNetRelativeAttention(nn.Module):
|
class XLNetRelativeAttention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
|||||||
Reference in New Issue
Block a user