torch.cuda.is_available() is redundant as apex handles that internally (#9350)
This commit is contained in:
@@ -110,13 +110,12 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
|
|
||||||
def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
|
def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
|
||||||
if torch.cuda.is_available():
|
try:
|
||||||
try:
|
from apex.normalization import FusedLayerNorm
|
||||||
from apex.normalization import FusedLayerNorm
|
|
||||||
|
|
||||||
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -265,14 +265,12 @@ FSMT_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
|
|
||||||
have_fused_layer_norm = False
|
have_fused_layer_norm = False
|
||||||
if torch.cuda.is_available():
|
try:
|
||||||
try:
|
from apex.normalization import FusedLayerNorm
|
||||||
from apex.normalization import FusedLayerNorm
|
|
||||||
|
|
||||||
have_fused_layer_norm = True
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
have_fused_layer_norm = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm
|
LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -511,13 +511,12 @@ class ProphetNetDecoderLMOutput(ModelOutput):
|
|||||||
|
|
||||||
|
|
||||||
def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
|
def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
|
||||||
if torch.cuda.is_available():
|
try:
|
||||||
try:
|
from apex.normalization import FusedLayerNorm
|
||||||
from apex.normalization import FusedLayerNorm
|
|
||||||
|
|
||||||
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user