Add sdpa for DistilBert (#33724)
* Add sdpa for DistilBert
* [run_slow] distilbert
* [run_slow] distilbert
* [run_slow] distilbert
* Try without slow tests
* [run_slow] distilbert
* [run_slow] distilbert
@@ -66,6 +66,53 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
 * predicting the masked tokens correctly (but no next-sentence objective)
 * a cosine similarity between the hidden states of the student and the teacher model
 
+### Using Scaled Dot Product Attention (SDPA)
+
+PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
+encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
+[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
+page for more information.
+
+SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
+`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
+
+```
+import torch
+from transformers import DistilBertModel
+model = DistilBertModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
+```
+
+For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
+
+On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16` and the `distilbert-base-uncased` model with
+a MaskedLM head, we saw the following speedups during training and inference.
+
+#### Training
+
+| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
+|--------------------|------------|---------|---------|----------------------------|---------------------------|-------------|---------------------|--------------------|----------------|
+| 100 | 1 | 128 | False | 0.010 | 0.008 | 28.870 | 397.038 | 399.629 | -0.649 |
+| 100 | 1 | 256 | False | 0.011 | 0.009 | 20.681 | 412.505 | 412.606 | -0.025 |
+| 100 | 2 | 128 | False | 0.011 | 0.009 | 23.741 | 412.213 | 412.606 | -0.095 |
+| 100 | 2 | 256 | False | 0.015 | 0.013 | 16.502 | 427.491 | 425.787 | 0.400 |
+| 100 | 4 | 128 | False | 0.015 | 0.013 | 13.828 | 427.491 | 425.787 | 0.400 |
+| 100 | 4 | 256 | False | 0.025 | 0.022 | 12.882 | 594.156 | 502.745 | 18.182 |
+| 100 | 8 | 128 | False | 0.023 | 0.022 | 8.010 | 545.922 | 502.745 | 8.588 |
+| 100 | 8 | 256 | False | 0.046 | 0.041 | 12.763 | 983.450 | 798.480 | 23.165 |
+
+#### Inference
+
+| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
+|-------------|------------|---------|---------|---------|----------|-----------------------------|-----------------------------|-------------|----------------|--------------|---------------|
+| 50 | 2 | 64 | True | True | True | 0.032 | 0.025 | 28.192 | 154.532 | 155.531 | -0.642 |
+| 50 | 2 | 128 | True | True | True | 0.033 | 0.025 | 32.636 | 157.286 | 157.482 | -0.125 |
+| 50 | 4 | 64 | True | True | True | 0.032 | 0.026 | 24.783 | 157.023 | 157.449 | -0.271 |
+| 50 | 4 | 128 | True | True | True | 0.034 | 0.028 | 19.299 | 162.794 | 162.269 | 0.323 |
+| 50 | 8 | 64 | True | True | True | 0.035 | 0.028 | 25.105 | 160.958 | 162.204 | -0.768 |
+| 50 | 8 | 128 | True | True | True | 0.052 | 0.046 | 12.375 | 173.155 | 171.844 | 0.763 |
+| 50 | 16 | 64 | True | True | True | 0.051 | 0.045 | 12.882 | 172.106 | 171.713 | 0.229 |
+| 50 | 16 | 128 | True | True | True | 0.096 | 0.081 | 18.524 | 191.257 | 191.517 | -0.136 |
+
 
 ## Resources
 
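The commit does not say how the percentage columns in the benchmark tables were computed. The memory figures are consistent with `(eager / sdpa - 1) * 100`; the sketch below assumes that formula, and `relative_gain_percent` is a hypothetical helper name, not part of the benchmark script. The speedup column presumably applies the same formula to unrounded timings, so recomputing it from the rounded times in the table will not reproduce it exactly.

```python
def relative_gain_percent(eager: float, sdpa: float) -> float:
    """Return how much larger the eager figure is than the SDPA figure, in percent."""
    return (eager / sdpa - 1.0) * 100.0


# Peak training memory for batch_size=4, seq_len=256 (values copied from the training table).
mem_saving = relative_gain_percent(594.156, 502.745)
print(f"{mem_saving:.3f}")  # the table reports 18.182 for this row
```

A negative result means SDPA used slightly more of the resource than eager attention, which is what the small-batch rows show for memory.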
@@ -219,6 +219,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
 * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
 * [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
+* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
 * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import get_activation
 from ...configuration_utils import PretrainedConfig
 from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
@@ -38,7 +39,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_2_2,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -329,6 +335,86 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
         return (attn_output,)
 
 
+class DistilBertSdpaAttention(MultiHeadSelfAttention):
+    def __init__(self, config: PretrainedConfig):
+        super().__init__(config=config)
+        self.dropout_prob = config.attention_dropout
+        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, ...]:
+        """
+        Parameters:
+            query: torch.tensor(bs, seq_length, dim)
+            key: torch.tensor(bs, seq_length, dim)
+            value: torch.tensor(bs, seq_length, dim)
+            mask: torch.tensor(bs, seq_length)
+
+        Returns:
+            context: torch.tensor(bs, seq_length, dim) Contextualized layer.
+            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights. Optional: only if
+            `output_attentions=True`
+        """
+        if output_attentions or head_mask is not None:
+            logger.warning_once(
+                "DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
+                " `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
+                " the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
+                ' removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                query,
+                key,
+                value,
+                mask,
+                head_mask,
+                output_attentions,
+            )
+
+        batch_size, _, _ = query.size()
+        dim_per_head = self.dim // self.n_heads
+
+        def shape(x: torch.Tensor) -> torch.Tensor:
+            """separate heads"""
+            return x.view(batch_size, -1, self.n_heads, dim_per_head).transpose(1, 2)
+
+        def unshape(x: torch.Tensor) -> torch.Tensor:
+            """group heads"""
+            return x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * dim_per_head)
+
+        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
+        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
+        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and q.device.type == "cuda" and mask is not None:
+            q = q.contiguous()
+            k = k.contiguous()
+            v = v.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=False,
+        )
+
+        attn_output = unshape(attn_output)
+        attn_output = self.out_lin(attn_output)
+
+        return (attn_output,)
+
+
 class FFN(nn.Module):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
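The new class hands the per-head `q`, `k`, `v` tensors and an attention mask to `torch.nn.functional.scaled_dot_product_attention`. As a rough check of what that call computes, the sketch below compares it with a hand-written softmax attention on random tensors. It assumes PyTorch 2.0 or later and an additive float mask; the tensor sizes and variable names are illustrative and not taken from the commit, and the manual path is a generic formulation, not DistilBERT's eager implementation verbatim.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, n_heads, seq_len, dim_per_head = 2, 4, 5, 8
q = torch.randn(bs, n_heads, seq_len, dim_per_head)
k = torch.randn(bs, n_heads, seq_len, dim_per_head)
v = torch.randn(bs, n_heads, seq_len, dim_per_head)

# Padding mask: 1 = real token, 0 = padding (the second sequence ends with two pads).
padding = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])

# Additive mask of shape (bs, 1, q_len, k_len): 0 where attention is allowed,
# the most negative finite value where the key position is padding.
additive = torch.zeros(bs, 1, seq_len, seq_len)
additive = additive.masked_fill(padding[:, None, None, :] == 0, torch.finfo(q.dtype).min)

# Hand-written attention: softmax(Q K^T / sqrt(d) + mask) V.
scores = q @ k.transpose(-1, -2) / math.sqrt(dim_per_head) + additive
manual = torch.softmax(scores, dim=-1) @ v

# Fused operator with the same arguments the new class passes (dropout disabled, non-causal).
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=additive, dropout_p=0.0, is_causal=False)

max_diff = (manual - fused).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The two results should agree up to floating-point noise. The fused call returns only the attention output, never the attention weights, which is why the class falls back to the parent implementation when `output_attentions=True` or a `head_mask` is requested.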
@@ -353,6 +439,7 @@ class FFN(nn.Module):
 DISTILBERT_ATTENTION_CLASSES = {
     "eager": MultiHeadSelfAttention,
     "flash_attention_2": DistilBertFlashAttention2,
+    "sdpa": DistilBertSdpaAttention,
 }
 
 
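This hunk only registers the new class under the `"sdpa"` key; the code that consumes the mapping is not part of the diff. A minimal sketch of the registry pattern, assuming the layer class is looked up by the configured implementation string, could look like this. Every name below (`ToyConfig`, `ToyEagerAttention`, `ToySdpaAttention`, `TOY_ATTENTION_CLASSES`, `build_attention`) is hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for the model config; only the field used here.
    attn_implementation: str = "eager"


class ToyEagerAttention:
    def __init__(self, config: ToyConfig):
        self.config = config


class ToySdpaAttention(ToyEagerAttention):
    """Subclasses the eager layer so it can fall back to the parent's forward."""


TOY_ATTENTION_CLASSES = {
    "eager": ToyEagerAttention,
    "sdpa": ToySdpaAttention,
}


def build_attention(config: ToyConfig) -> ToyEagerAttention:
    """Instantiate the attention class registered for the configured implementation."""
    try:
        attention_cls = TOY_ATTENTION_CLASSES[config.attn_implementation]
    except KeyError:
        raise ValueError(f"Unknown attention implementation: {config.attn_implementation!r}") from None
    return attention_cls(config)


layer = build_attention(ToyConfig(attn_implementation="sdpa"))
print(type(layer).__name__)  # ToySdpaAttention
```

Keeping the variants in one dictionary means a new backend needs only a new class and a single added entry, which matches the shape of this hunk.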
@@ -503,6 +590,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     base_model_prefix = "distilbert"
     supports_gradient_checkpointing = True
     _supports_flash_attn_2 = True
+    _supports_sdpa = True
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
@@ -589,6 +677,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.embeddings = Embeddings(config)  # Embeddings
         self.transformer = Transformer(config)  # Encoder
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -689,6 +778,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
 
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
+        head_mask_is_none = head_mask is None
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
@@ -700,6 +790,11 @@ class DistilBertModel(DistilBertPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)
 
+        if self._use_sdpa and head_mask_is_none and not output_attentions:
+            attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                attention_mask, embeddings.dtype, tgt_len=input_shape[1]
+            )
+
         return self.transformer(
             x=embeddings,
             attn_mask=attention_mask,
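On the SDPA path, the hunk above converts the `(bs, seq_length)` padding mask into a 4D mask before it reaches the attention layers. The dependency-free sketch below shows what such an expansion amounts to, assuming the result is an additive mask of shape `(bs, 1, tgt_len, src_len)`. It uses nested lists and `-inf` purely for illustration; `expand_padding_mask` is a hypothetical helper, and the real `_prepare_4d_attention_mask_for_sdpa` works on tensors, with a fill value and shortcuts that are not shown in this diff.

```python
from typing import List

NEG_INF = float("-inf")  # illustrative fill value for masked key positions


def expand_padding_mask(mask: List[List[int]], tgt_len: int) -> List[List[List[List[float]]]]:
    """Expand a (batch, src_len) 0/1 padding mask to an additive (batch, 1, tgt_len, src_len) mask."""
    expanded = []
    for row in mask:
        additive_row = [0.0 if keep else NEG_INF for keep in row]
        # Every query position shares the same key mask; copy the row so entries are independent.
        expanded.append([[list(additive_row) for _ in range(tgt_len)]])
    return expanded


mask = [[1, 1, 0], [1, 1, 1]]
mask_4d = expand_padding_mask(mask, tgt_len=3)
print(len(mask_4d), len(mask_4d[0]), len(mask_4d[0][0]), len(mask_4d[0][0][0]))  # 2 1 3 3
print(mask_4d[0][0][0])  # [0.0, 0.0, -inf]
```

The singleton second dimension lets one mask broadcast across all heads. The conversion is skipped when a `head_mask` is given or `output_attentions=True`, because those calls are served by the eager fallback, which expects the original 2D mask.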