Merge pull request #962 from guotong1988/patch-1

Update modeling_xlnet.py
This commit is contained in:
Thomas Wolf
2019-08-07 10:09:20 +02:00
committed by GitHub

View File

@@ -335,7 +335,7 @@ class XLNetConfig(PretrainedConfig):
try: try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
except ImportError: except (ImportError, AttributeError) as e:
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
class XLNetLayerNorm(nn.Module): class XLNetLayerNorm(nn.Module):
def __init__(self, d_model, eps=1e-12): def __init__(self, d_model, eps=1e-12):