Merge pull request #962 from guotong1988/patch-1
Update modeling_xlnet.py
This commit is contained in:
@@ -335,7 +335,7 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
|
from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
|
||||||
except ImportError:
|
except (ImportError, AttributeError) as e:
|
||||||
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
|
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
|
||||||
class XLNetLayerNorm(nn.Module):
|
class XLNetLayerNorm(nn.Module):
|
||||||
def __init__(self, d_model, eps=1e-12):
|
def __init__(self, d_model, eps=1e-12):
|
||||||
|
|||||||
Reference in New Issue
Block a user