[t5/t0/mt5 models] faster/leaner custom layer norm (#14656)
* [t5] faster/leaner custom layer norm * wip * apex.normalization.FusedRMSNorm * cleanup * cleanup * add doc * add catch all * Trigger CI * expand
This commit is contained in:
@@ -263,6 +263,11 @@ print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
|
||||
|
||||
<a id='scripts'></a>
|
||||
|
||||
## Performance
|
||||
|
||||
If you'd like a faster training and inference performance, install [apex](https://github.com/NVIDIA/apex#quick-start) and then the model will automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`. The former uses an optimized fused kernel which is several times faster than the latter.
|
||||
|
||||
|
||||
## Example scripts
|
||||
|
||||
T5 is supported by several example scripts, both for pre-training and fine-tuning.
|
||||
|
||||
@@ -237,14 +237,19 @@ DEPARALLELIZE_DOCSTRING = r"""
|
||||
class T5LayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Construct a layernorm module in the T5 style No bias and no subtraction of mean.
|
||||
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# layer norm should always be calculated in float32
|
||||
|
||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||
# half-precision inputs is done in fp32
|
||||
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
@@ -255,6 +260,20 @@ class T5LayerNorm(nn.Module):
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm
|
||||
|
||||
T5LayerNorm = FusedRMSNorm # noqa
|
||||
|
||||
logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
|
||||
except ImportError:
|
||||
# using the normal T5LayerNorm
|
||||
pass
|
||||
except Exception:
|
||||
logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
|
||||
pass
|
||||
|
||||
|
||||
class T5DenseReluDense(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user