Add sdpa for DistilBert (#33724)
* Add sdpa for DistilBert * [run_slow] distilbert * [run_slow] distilbert * [run_slow] distilbert * Try without slow tests * [run_slow] distilbert * [run_slow] distilbert
This commit is contained in:
@@ -66,6 +66,53 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
|
||||
* predicting the masked tokens correctly (but no next-sentence objective)
|
||||
* a cosine similarity between the hidden states of the student and the teacher model
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```
|
||||
from transformers import DistilBertModel
|
||||
model = DistilBertModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16` and the `distilbert-base-uncased` model with
|
||||
a MaskedLM head, we saw the following speedups during training and inference.
|
||||
|
||||
#### Training
|
||||
|
||||
| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|
||||
|--------------------|------------|---------|---------|----------------------------|---------------------------|-------------|---------------------|--------------------|----------------|
|
||||
| 100 | 1 | 128 | False | 0.010 | 0.008 | 28.870 | 397.038 | 399.629 | -0.649 |
|
||||
| 100 | 1 | 256 | False | 0.011 | 0.009 | 20.681 | 412.505 | 412.606 | -0.025 |
|
||||
| 100 | 2 | 128 | False | 0.011 | 0.009 | 23.741 | 412.213 | 412.606 | -0.095 |
|
||||
| 100 | 2 | 256 | False | 0.015 | 0.013 | 16.502 | 427.491 | 425.787 | 0.400 |
|
||||
| 100 | 4 | 128 | False | 0.015 | 0.013 | 13.828 | 427.491 | 425.787 | 0.400 |
|
||||
| 100 | 4 | 256 | False | 0.025 | 0.022 | 12.882 | 594.156 | 502.745 | 18.182 |
|
||||
| 100 | 8 | 128 | False | 0.023 | 0.022 | 8.010 | 545.922 | 502.745 | 8.588 |
|
||||
| 100 | 8 | 256 | False | 0.046 | 0.041 | 12.763 | 983.450 | 798.480 | 23.165 |
|
||||
|
||||
#### Inference
|
||||
|
||||
| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|
||||
|-------------|------------|---------|---------|---------|----------|-----------------------------|-----------------------------|-------------|----------------|--------------|---------------|
|
||||
| 50 | 2 | 64 | True | True | True | 0.032 | 0.025 | 28.192 | 154.532 | 155.531 | -0.642 |
|
||||
| 50 | 2 | 128 | True | True | True | 0.033 | 0.025 | 32.636 | 157.286 | 157.482 | -0.125 |
|
||||
| 50 | 4 | 64 | True | True | True | 0.032 | 0.026 | 24.783 | 157.023 | 157.449 | -0.271 |
|
||||
| 50 | 4 | 128 | True | True | True | 0.034 | 0.028 | 19.299 | 162.794 | 162.269 | 0.323 |
|
||||
| 50 | 8 | 64 | True | True | True | 0.035 | 0.028 | 25.105 | 160.958 | 162.204 | -0.768 |
|
||||
| 50 | 8 | 128 | True | True | True | 0.052 | 0.046 | 12.375 | 173.155 | 171.844 | 0.763 |
|
||||
| 50 | 16 | 64 | True | True | True | 0.051 | 0.045 | 12.882 | 172.106 | 171.713 | 0.229 |
|
||||
| 50 | 16 | 128 | True | True | True | 0.096 | 0.081 | 18.524 | 191.257 | 191.517 | -0.136 |
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
|
||||
@@ -219,6 +219,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
|
||||
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
|
||||
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
|
||||
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
||||
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
|
||||
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
|
||||
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
||||
|
||||
@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import get_activation
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
MaskedLMOutput,
|
||||
@@ -38,7 +39,12 @@ from ...modeling_outputs import (
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...pytorch_utils import (
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
is_torch_greater_or_equal_than_2_2,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
@@ -329,6 +335,86 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
|
||||
return (attn_output,)
|
||||
|
||||
|
||||
class DistilBertSdpaAttention(MultiHeadSelfAttention):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__(config=config)
|
||||
self.dropout_prob = config.attention_dropout
|
||||
self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Parameters:
|
||||
query: torch.tensor(bs, seq_length, dim)
|
||||
key: torch.tensor(bs, seq_length, dim)
|
||||
value: torch.tensor(bs, seq_length, dim)
|
||||
mask: torch.tensor(bs, seq_length)
|
||||
|
||||
Returns:
|
||||
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
|
||||
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
|
||||
"""
|
||||
if output_attentions or head_mask is not None:
|
||||
logger.warning_once(
|
||||
"DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
|
||||
" `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
|
||||
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
|
||||
' removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
mask,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
batch_size, _, _ = query.size()
|
||||
dim_per_head = self.dim // self.n_heads
|
||||
|
||||
def shape(x: torch.Tensor) -> torch.Tensor:
|
||||
"""separate heads"""
|
||||
return x.view(batch_size, -1, self.n_heads, dim_per_head).transpose(1, 2)
|
||||
|
||||
def unshape(x: torch.Tensor) -> torch.Tensor:
|
||||
"""group heads"""
|
||||
return x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * dim_per_head)
|
||||
|
||||
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
|
||||
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
|
||||
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
|
||||
|
||||
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
||||
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
if self.require_contiguous_qkv and q.device.type == "cuda" and mask is not None:
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=mask,
|
||||
dropout_p=self.dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
attn_output = unshape(attn_output)
|
||||
attn_output = self.out_lin(attn_output)
|
||||
|
||||
return (attn_output,)
|
||||
|
||||
|
||||
class FFN(nn.Module):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
@@ -353,6 +439,7 @@ class FFN(nn.Module):
|
||||
DISTILBERT_ATTENTION_CLASSES = {
|
||||
"eager": MultiHeadSelfAttention,
|
||||
"flash_attention_2": DistilBertFlashAttention2,
|
||||
"sdpa": DistilBertSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
@@ -503,6 +590,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "distilbert"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights."""
|
||||
@@ -589,6 +677,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
self.embeddings = Embeddings(config) # Embeddings
|
||||
self.transformer = Transformer(config) # Encoder
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@@ -689,6 +778,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
head_mask_is_none = head_mask is None
|
||||
# Prepare head mask if needed
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
@@ -700,6 +790,11 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)
|
||||
|
||||
if self._use_sdpa and head_mask_is_none and not output_attentions:
|
||||
attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||
attention_mask, embeddings.dtype, tgt_len=input_shape[1]
|
||||
)
|
||||
|
||||
return self.transformer(
|
||||
x=embeddings,
|
||||
attn_mask=attention_mask,
|
||||
|
||||
Reference in New Issue
Block a user