Add sdpa support for Albert (#32092)
* Add sdpa support for Albert
* [run_slow] albert
* Add benchmarks and PR suggestion
* Fix quality
* Fix
* [run_slow] albert
@@ -59,7 +59,52 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). This
 - Layers are split in groups that share parameters (to save memory).
 Next sentence prediction is replaced by a sentence ordering prediction: in the inputs, we have two sentences A and B (that are consecutive) and we either feed A followed by B or B followed by A. The model must predict if they have been swapped or not.
 
+### Using Scaled Dot Product Attention (SDPA)
+
+PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
+encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
+[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
+page for more information.
+
+SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
+`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
+
+```
+import torch
+from transformers import AlbertModel
+model = AlbertModel.from_pretrained("albert/albert-base-v1", torch_dtype=torch.float16, attn_implementation="sdpa")
+...
+```
+
+For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
+
+On a local benchmark (GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16`, we saw the
+following speedups during training and inference.
+
+#### Training for 100 iterations
+
+|batch_size|seq_len|Time per batch (eager - s)| Time per batch (sdpa - s)| Speedup (%)| Eager peak mem (MB)| sdpa peak mem (MB)| Mem saving (%)|
+|----------|-------|--------------------------|--------------------------|------------|--------------------|-------------------|---------------|
+|2 |256 |0.028 |0.024 |14.388 |358.411 |321.088 |11.624 |
+|2 |512 |0.049 |0.041 |17.681 |753.458 |602.660 |25.022 |
+|4 |256 |0.044 |0.039 |12.246 |679.534 |602.660 |12.756 |
+|4 |512 |0.090 |0.076 |18.472 |1434.820 |1134.140 |26.512 |
+|8 |256 |0.081 |0.072 |12.664 |1283.825 |1134.140 |13.198 |
+|8 |512 |0.170 |0.143 |18.957 |2820.398 |2219.695 |27.062 |
+
+#### Inference with 50 batches
+
+|batch_size|seq_len|Per token latency eager (ms)|Per token latency SDPA (ms)|Speedup (%) |Mem eager (MB)|Mem SDPA (MB)|Mem saved (%)|
+|----------|-------|----------------------------|---------------------------|------------|--------------|-------------|-------------|
+|4 |128 |0.083 |0.071 |16.967 |48.319 |48.45 |-0.268 |
+|4 |256 |0.148 |0.127 |16.37 |63.4 |63.922 |-0.817 |
+|4 |512 |0.31 |0.247 |25.473 |110.092 |94.343 |16.693 |
+|8 |128 |0.137 |0.124 |11.102 |63.4 |63.66 |-0.409 |
+|8 |256 |0.271 |0.231 |17.271 |91.202 |92.246 |-1.132 |
+|8 |512 |0.602 |0.48 |25.47 |186.159 |152.564 |22.021 |
+|16 |128 |0.252 |0.224 |12.506 |91.202 |91.722 |-0.567 |
+|16 |256 |0.526 |0.448 |17.604 |148.378 |150.467 |-1.388 |
+|16 |512 |1.203 |0.96 |25.365 |338.293 |271.102 |24.784 |
+
 This model was contributed by [lysandre](https://huggingface.co/lysandre). This model jax version was contributed by
 [kamalkraj](https://huggingface.co/kamalkraj). The original code can be found [here](https://github.com/google-research/ALBERT).
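Note on the documentation hunk above: the added section says the SDPA operator "encompasses several implementations". They all compute the same quantity, `softmax(Q K^T / sqrt(d) + mask) V`. The following is a minimal pure-Python sketch of that computation for a single attention head, assuming an additive float mask. The names `softmax` and `sdpa_reference` are hypothetical and used for illustration only; this is not the PyTorch implementation and it ignores dropout and batching.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa_reference(query, key, value, attn_mask=None):
    """Reference attention for one head (hypothetical helper).

    query: [L][d], key: [S][d], value: [S][dv]
    attn_mask: optional additive mask of shape [L][S]
    """
    scale = 1.0 / math.sqrt(len(query[0]))
    output = []
    for i, q in enumerate(query):
        # Scaled dot product of this query with every key.
        scores = [scale * sum(qc * kc for qc, kc in zip(q, k)) for k in key]
        if attn_mask is not None:
            # Additive mask: large negative values suppress a position.
            scores = [s + m for s, m in zip(scores, attn_mask[i])]
        weights = softmax(scores)
        # Weighted sum of the value vectors.
        output.append(
            [sum(w * v[c] for w, v in zip(weights, value)) for c in range(len(value[0]))]
        )
    return output
```

Fused SDPA kernels produce this result without materializing the full `[L][S]` weight matrix, which is consistent with the memory savings reported for longer sequences in the tables above.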
@@ -201,6 +201,7 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
 PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
 
 For now, Transformers supports SDPA inference and training for the following architectures:
+* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
 * [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
 * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
 * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
@@ -24,6 +24,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
@@ -34,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_2_2,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -358,6 +364,66 @@ class AlbertAttention(nn.Module):
         return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
 
 
+class AlbertSdpaAttention(AlbertAttention):
+    def __init__(self, config):
+        super().__init__(config)
+        self.dropout_prob = config.attention_probs_dropout_prob
+        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
+        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
+            logger.warning(
+                "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
+                "the eager attention implementation, but specifying the eager implementation will be required from "
+                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
+                '`attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(hidden_states, attention_mask, head_mask, output_attentions)
+
+        batch_size, seq_len, _ = hidden_states.size()
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+            query_layer = query_layer.contiguous()
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
+        attention_output = torch.nn.functional.scaled_dot_product_attention(
+            query=query_layer,
+            key=key_layer,
+            value=value_layer,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=False,
+        )
+
+        attention_output = attention_output.transpose(1, 2)
+        attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
+
+        projected_context_layer = self.dense(attention_output)
+        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
+        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
+        return (layernormed_context_layer,)
+
+
+ALBERT_ATTENTION_CLASSES = {
+    "eager": AlbertAttention,
+    "sdpa": AlbertSdpaAttention,
+}
+
+
 class AlbertLayer(nn.Module):
     def __init__(self, config: AlbertConfig):
         super().__init__()
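Note on the hunk above: it combines a registry dictionary keyed by implementation name with a subclass that falls back to the parent's eager `forward()` when a request cannot be served by the fused operator. That structure can be sketched without PyTorch. Every name below (`EagerAttention`, `SdpaAttention`, `ATTENTION_CLASSES`, `build_attention`) is a hypothetical stand-in, and the string results only mark which code path ran.

```python
class EagerAttention:
    """Stand-in for the eager path, which can also return attention weights."""

    def forward(self, hidden, output_attentions=False):
        context = f"eager({hidden})"
        # The eager path computes the probabilities explicitly, so it can return them.
        return (context, "attention_probs") if output_attentions else (context,)


class SdpaAttention(EagerAttention):
    """Stand-in for the fused path, which only returns the attention output."""

    def forward(self, hidden, output_attentions=False):
        if output_attentions:
            # The fused operator returns no attention weights, so defer to the
            # parent implementation instead of failing.
            return super().forward(hidden, output_attentions)
        return (f"sdpa({hidden})",)


# Registry keyed by the configured implementation name.
ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}


def build_attention(attn_implementation):
    """Instantiate the attention class selected by name (hypothetical helper)."""
    return ATTENTION_CLASSES[attn_implementation]()
```

Because the fused class inherits from the eager one, the fallback reuses the same weights and needs no second module. The unsupported case is handled per call, not at construction time.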
@@ -366,7 +432,7 @@ class AlbertLayer(nn.Module):
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
         self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = AlbertAttention(config)
+        self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
         self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
         self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
         self.activation = ACT2FN[config.hidden_act]
@@ -496,6 +562,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
     config_class = AlbertConfig
     load_tf_weights = load_tf_weights_in_albert
     base_model_prefix = "albert"
+    _supports_sdpa = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -635,6 +702,9 @@ class AlbertModel(AlbertPreTrainedModel):
             self.pooler = None
             self.pooler_activation = None
 
+        self.attn_implementation = config._attn_implementation
+        self.position_embedding_type = config.position_embedding_type
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -708,14 +778,28 @@ class AlbertModel(AlbertPreTrainedModel):
             else:
                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
-        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
-        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
         embedding_output = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
         )
+
+        use_sdpa_attention_mask = (
+            self.attn_implementation == "sdpa"
+            and self.position_embedding_type == "absolute"
+            and head_mask is None
+            and not output_attentions
+        )
+
+        if use_sdpa_attention_mask:
+            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                attention_mask, embedding_output.dtype, tgt_len=seq_length
+            )
+        else:
+            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
+
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
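Note on the hunk above: both branches turn a `[batch, src_len]` padding mask of ones and zeros into an additive mask, where kept positions contribute `0` and padded positions contribute the most negative representable value. The eager branch leaves the result broadcastable as `[batch, 1, 1, src_len]`, while the SDPA branch is passed `tgt_len`, so it can build a full `[batch, 1, tgt_len, src_len]` mask. The sketch below shows that arithmetic on nested lists. `FLOAT_MIN`, `additive_mask`, and `expand_mask_4d` are hypothetical names, and Python's float minimum stands in for `torch.finfo(dtype).min`. It does not reproduce the library helper `_prepare_4d_attention_mask_for_sdpa`.

```python
import sys

# Assumption: stand-in for torch.finfo(dtype).min, using Python's double precision.
FLOAT_MIN = -sys.float_info.max


def additive_mask(attention_mask):
    """Map a [batch][src_len] mask of 1/0 to additive values: 1 -> 0, 0 -> FLOAT_MIN."""
    return [[(1.0 - float(keep)) * FLOAT_MIN for keep in row] for row in attention_mask]


def expand_mask_4d(attention_mask, tgt_len):
    """Expand to [batch][1][tgt_len][src_len], repeating each row for every query position."""
    return [
        [[list(row) for _ in range(tgt_len)]]  # singleton head dimension
        for row in additive_mask(attention_mask)
    ]
```

For example, `expand_mask_4d([[1, 1, 0]], tgt_len=3)` yields one batch entry with one head dimension and three identical rows, each masking only the last position.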