[t5/t0/mt5 models] faster/leaner custom layer norm (#14656)
* [t5] faster/leaner custom layer norm * wip * apex.normalization.FusedRMSNorm * cleanup * cleanup * add doc * add catch all * Trigger CI * expand
This commit is contained in:
@@ -263,6 +263,11 @@ print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
|
|||||||
|
|
||||||
<a id='scripts'></a>
|
<a id='scripts'></a>
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
If you'd like faster training and inference performance, install [apex](https://github.com/NVIDIA/apex#quick-start) and then the model will automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`. The former uses an optimized fused kernel which is several times faster than the latter.
|
||||||
|
|
||||||
|
|
||||||
## Example scripts
|
## Example scripts
|
||||||
|
|
||||||
T5 is supported by several example scripts, both for pre-training and fine-tuning.
|
T5 is supported by several example scripts, both for pre-training and fine-tuning.
|
||||||
|
|||||||
@@ -237,14 +237,19 @@ DEPARALLELIZE_DOCSTRING = r"""
|
|||||||
class T5LayerNorm(nn.Module):
|
class T5LayerNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
Construct a layernorm module in the T5 style No bias and no subtraction of mean.
|
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# layer norm should always be calculated in float32
|
|
||||||
|
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
|
# half-precision inputs is done in fp32
|
||||||
|
|
||||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
|
||||||
@@ -255,6 +260,20 @@ class T5LayerNorm(nn.Module):
|
|||||||
return self.weight * hidden_states
|
return self.weight * hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from apex.normalization import FusedRMSNorm
|
||||||
|
|
||||||
|
T5LayerNorm = FusedRMSNorm # noqa
|
||||||
|
|
||||||
|
logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
|
||||||
|
except ImportError:
|
||||||
|
# using the normal T5LayerNorm
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class T5DenseReluDense(nn.Module):
|
class T5DenseReluDense(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
Reference in New Issue
Block a user