From bee361c6f1f7704f8c688895f2f86f6e5ff84727 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 15 Feb 2022 16:49:57 -0800 Subject: [PATCH] [t5/t0/mt5 models] faster/leaner custom layer norm (#14656) * [t5] faster/leaner custom layer norm * wip * apex.normalization.FusedRMSNorm * cleanup * cleanup * add doc * add catch all * Trigger CI * expand --- docs/source/model_doc/t5.mdx | 5 +++++ src/transformers/models/t5/modeling_t5.py | 23 +++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/docs/source/model_doc/t5.mdx b/docs/source/model_doc/t5.mdx index 47bcdc662f..dbcfaf1c7d 100644 --- a/docs/source/model_doc/t5.mdx +++ b/docs/source/model_doc/t5.mdx @@ -263,6 +263,11 @@ print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True)) +## Performance + +If you'd like faster training and inference performance, install [apex](https://github.com/NVIDIA/apex#quick-start) and then the model will automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`. The former uses an optimized fused kernel which is several times faster than the latter. + + ## Example scripts T5 is supported by several example scripts, both for pre-training and fine-tuning. diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 3af2a53de2..0c211caccf 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -237,14 +237,19 @@ DEPARALLELIZE_DOCSTRING = r""" class T5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - Construct a layernorm module in the T5 style No bias and no subtraction of mean. + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. 
""" super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): - # layer norm should always be calculated in float32 + + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) @@ -255,6 +260,20 @@ class T5LayerNorm(nn.Module): return self.weight * hidden_states +try: + from apex.normalization import FusedRMSNorm + + T5LayerNorm = FusedRMSNorm # noqa + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm") +except ImportError: + # using the normal T5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") + pass + + class T5DenseReluDense(nn.Module): def __init__(self, config): super().__init__()