Add sdpa for DistilBert (#33724)

* Add sdpa for DistilBert

* [run_slow] distilbert

* [run_slow] distilbert

* [run_slow] distilbert

* Try without slow tests

* [run_slow] distilbert

* [run_slow] distilbert
This commit is contained in:
Omar Salman
2024-10-02 17:55:19 +05:00
committed by GitHub
parent 614c79a9b0
commit e7c8af7f33
3 changed files with 144 additions and 1 deletions

View File

@@ -66,6 +66,53 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
* predicting the masked tokens correctly (but no next-sentence objective) * predicting the masked tokens correctly (but no next-sentence objective)
* a cosine similarity between the hidden states of the student and the teacher model * a cosine similarity between the hidden states of the student and the teacher model
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import DistilBertModel
model = DistilBertModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16` and the `distilbert-base-uncased` model with
a MaskedLM head, we saw the following speedups during training and inference.
#### Training
| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|--------------------|------------|---------|---------|----------------------------|---------------------------|-------------|---------------------|--------------------|----------------|
| 100 | 1 | 128 | False | 0.010 | 0.008 | 28.870 | 397.038 | 399.629 | -0.649 |
| 100 | 1 | 256 | False | 0.011 | 0.009 | 20.681 | 412.505 | 412.606 | -0.025 |
| 100 | 2 | 128 | False | 0.011 | 0.009 | 23.741 | 412.213 | 412.606 | -0.095 |
| 100 | 2 | 256 | False | 0.015 | 0.013 | 16.502 | 427.491 | 425.787 | 0.400 |
| 100 | 4 | 128 | False | 0.015 | 0.013 | 13.828 | 427.491 | 425.787 | 0.400 |
| 100 | 4 | 256 | False | 0.025 | 0.022 | 12.882 | 594.156 | 502.745 | 18.182 |
| 100 | 8 | 128 | False | 0.023 | 0.022 | 8.010 | 545.922 | 502.745 | 8.588 |
| 100 | 8 | 256 | False | 0.046 | 0.041 | 12.763 | 983.450 | 798.480 | 23.165 |
#### Inference
| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|-------------|------------|---------|---------|---------|----------|-----------------------------|-----------------------------|-------------|----------------|--------------|---------------|
| 50 | 2 | 64 | True | True | True | 0.032 | 0.025 | 28.192 | 154.532 | 155.531 | -0.642 |
| 50 | 2 | 128 | True | True | True | 0.033 | 0.025 | 32.636 | 157.286 | 157.482 | -0.125 |
| 50 | 4 | 64 | True | True | True | 0.032 | 0.026 | 24.783 | 157.023 | 157.449 | -0.271 |
| 50 | 4 | 128 | True | True | True | 0.034 | 0.028 | 19.299 | 162.794 | 162.269 | 0.323 |
| 50 | 8 | 64 | True | True | True | 0.035 | 0.028 | 25.105 | 160.958 | 162.204 | -0.768 |
| 50 | 8 | 128 | True | True | True | 0.052 | 0.046 | 12.375 | 173.155 | 171.844 | 0.763 |
| 50 | 16 | 64 | True | True | True | 0.051 | 0.045 | 12.882 | 172.106 | 171.713 | 0.229 |
| 50 | 16 | 128 | True | True | True | 0.096 | 0.081 | 18.524 | 191.257 | 191.517 | -0.136 |
## Resources ## Resources

View File

@@ -219,6 +219,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel) * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2) * [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)

View File

@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import get_activation from ...activations import get_activation
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
MaskedLMOutput, MaskedLMOutput,
@@ -38,7 +39,12 @@ from ...modeling_outputs import (
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
is_torch_greater_or_equal_than_2_2,
prune_linear_layer,
)
from ...utils import ( from ...utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
@@ -329,6 +335,86 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
return (attn_output,) return (attn_output,)
class DistilBertSdpaAttention(MultiHeadSelfAttention):
def __init__(self, config: PretrainedConfig):
super().__init__(config=config)
self.dropout_prob = config.attention_dropout
self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, ...]:
"""
Parameters:
query: torch.tensor(bs, seq_length, dim)
key: torch.tensor(bs, seq_length, dim)
value: torch.tensor(bs, seq_length, dim)
mask: torch.tensor(bs, seq_length)
Returns:
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
"""
if output_attentions or head_mask is not None:
logger.warning_once(
"DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
" `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
' removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
query,
key,
value,
mask,
head_mask,
output_attentions,
)
batch_size, _, _ = query.size()
dim_per_head = self.dim // self.n_heads
def shape(x: torch.Tensor) -> torch.Tensor:
"""separate heads"""
return x.view(batch_size, -1, self.n_heads, dim_per_head).transpose(1, 2)
def unshape(x: torch.Tensor) -> torch.Tensor:
"""group heads"""
return x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * dim_per_head)
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and q.device.type == "cuda" and mask is not None:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=False,
)
attn_output = unshape(attn_output)
attn_output = self.out_lin(attn_output)
return (attn_output,)
class FFN(nn.Module): class FFN(nn.Module):
def __init__(self, config: PretrainedConfig): def __init__(self, config: PretrainedConfig):
super().__init__() super().__init__()
@@ -353,6 +439,7 @@ class FFN(nn.Module):
DISTILBERT_ATTENTION_CLASSES = { DISTILBERT_ATTENTION_CLASSES = {
"eager": MultiHeadSelfAttention, "eager": MultiHeadSelfAttention,
"flash_attention_2": DistilBertFlashAttention2, "flash_attention_2": DistilBertFlashAttention2,
"sdpa": DistilBertSdpaAttention,
} }
@@ -503,6 +590,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
base_model_prefix = "distilbert" base_model_prefix = "distilbert"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module: nn.Module): def _init_weights(self, module: nn.Module):
"""Initialize the weights.""" """Initialize the weights."""
@@ -589,6 +677,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.embeddings = Embeddings(config) # Embeddings self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder self.transformer = Transformer(config) # Encoder
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@@ -689,6 +778,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
head_mask_is_none = head_mask is None
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
@@ -700,6 +790,11 @@ class DistilBertModel(DistilBertPreTrainedModel):
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)
if self._use_sdpa and head_mask_is_none and not output_attentions:
attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embeddings.dtype, tgt_len=input_shape[1]
)
return self.transformer( return self.transformer(
x=embeddings, x=embeddings,
attn_mask=attention_mask, attn_mask=attention_mask,