[GPT2] Add SDPA support (#31172)
* `gpt2` sdpa support * fix (at least) one test, style, repo consistency * fix sdpa mask in forward --> fixes generation * test * test2 * test3 * test4 * simplify shapes for attn mask creation and small comments * hub fail test * benchmarks * flash attn 2 mask should not be inverted on enc-dec setup * fix comment * apply some suggestion from code review - only save _attn_implentation once - remove unnecessary comment * change elif logic * [run-slow] gpt2 * modify `test_gpt2_sample_max_time` to follow previous assertion patterns
This commit is contained in:
@@ -127,6 +127,64 @@ Below is an expected speedup diagram that compares pure inference time between t
|
|||||||
<img src="https://huggingface.co/datasets/EduardoPacheco/documentation-images/resolve/main/gpt2_flash_attention_2_speedup.jpg">
|
<img src="https://huggingface.co/datasets/EduardoPacheco/documentation-images/resolve/main/gpt2_flash_attention_2_speedup.jpg">
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
## Using Scaled Dot Product Attention (SDPA)
|
||||||
|
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||||
|
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||||
|
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||||
|
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||||
|
page for more information.
|
||||||
|
|
||||||
|
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||||
|
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||||
|
|
||||||
|
On a local benchmark (rtx3080ti-16GB, PyTorch 2.2.1, OS Ubuntu 22.04) using `float16` with
|
||||||
|
[gpt2-large](https://huggingface.co/openai-community/gpt2-large), we saw the
|
||||||
|
following speedups during training and inference.
|
||||||
|
|
||||||
|
### Training
|
||||||
|
| Batch size | Seq len | Time per batch (Eager - s) | Time per batch (SDPA - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|
||||||
|
|-----------:|--------:|----------------------------:|--------------------------:|------------:|--------------------:|-------------------:|------------------:|
|
||||||
|
| 1 | 128 | 0.039 | 0.032 | 23.042 | 3482.32 | 3494.62 | -0.352 |
|
||||||
|
| 1 | 256 | 0.073 | 0.059 | 25.15 | 3546.66 | 3552.6 | -0.167 |
|
||||||
|
| 1 | 512 | 0.155 | 0.118 | 30.96 | 4230.1 | 3665.59 | 15.4 |
|
||||||
|
| 1 | 1024 | 0.316 | 0.209 | 50.839 | 8682.26 | 4881.09 | 77.875 |
|
||||||
|
| 2 | 128 | 0.07 | 0.06 | 15.324 | 3557.8 | 3545.91 | 0.335 |
|
||||||
|
| 2 | 256 | 0.143 | 0.122 | 16.53 | 3901.5 | 3657.68 | 6.666 |
|
||||||
|
| 2 | 512 | 0.267 | 0.213 | 25.626 | 7062.21 | 4876.47 | 44.822 |
|
||||||
|
| 2 | 1024 | OOM | 0.404 | / | OOM | 8096.35 | SDPA does not OOM |
|
||||||
|
| 4 | 128 | 0.134 | 0.128 | 4.412 | 3675.79 | 3648.72 | 0.742 |
|
||||||
|
| 4 | 256 | 0.243 | 0.217 | 12.292 | 6129.76 | 4871.12 | 25.839 |
|
||||||
|
| 4 | 512 | 0.494 | 0.406 | 21.687 | 12466.6 | 8102.64 | 53.858 |
|
||||||
|
| 4 | 1024 | OOM | 0.795 | / | OOM | 14568.2 | SDPA does not OOM |
|
||||||
|
|
||||||
|
### Inference
|
||||||
|
| Batch size | Seq len | Per token latency Eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem Eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|
||||||
|
|-----------:|--------:|-----------------------------:|----------------------------:|------------:|---------------:|--------------:|--------------:|
|
||||||
|
| 1 | 128 | 7.991 | 6.968 | 14.681 | 1685.2 | 1701.32 | -0.947 |
|
||||||
|
| 1 | 256 | 8.462 | 7.199 | 17.536 | 1745.49 | 1770.78 | -1.428 |
|
||||||
|
| 1 | 512 | 8.68 | 7.853 | 10.529 | 1907.69 | 1921.29 | -0.708 |
|
||||||
|
| 1 | 768 | 9.101 | 8.365 | 8.791 | 2032.93 | 2068.12 | -1.701 |
|
||||||
|
| 2 | 128 | 9.169 | 9.001 | 1.861 | 1803.84 | 1811.4 | -0.418 |
|
||||||
|
| 2 | 256 | 9.907 | 9.78 | 1.294 | 1907.72 | 1921.44 | -0.714 |
|
||||||
|
| 2 | 512 | 11.519 | 11.644 | -1.071 | 2176.86 | 2197.75 | -0.951 |
|
||||||
|
| 2 | 768 | 13.022 | 13.407 | -2.873 | 2464.3 | 2491.06 | -1.074 |
|
||||||
|
| 4 | 128 | 10.097 | 9.831 | 2.709 | 1942.25 | 1985.13 | -2.16 |
|
||||||
|
| 4 | 256 | 11.599 | 11.398 | 1.764 | 2177.28 | 2197.86 | -0.937 |
|
||||||
|
| 4 | 512 | 14.653 | 14.45 | 1.411 | 2753.16 | 2772.57 | -0.7 |
|
||||||
|
| 4 | 768 | 17.846 | 17.617 | 1.299 | 3327.04 | 3343.97 | -0.506 |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Resources
|
## Resources
|
||||||
|
|
||||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||||
|
|||||||
@@ -201,6 +201,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
|
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
|
||||||
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
|
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
|
||||||
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
||||||
|
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
|
||||||
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
|
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
|
||||||
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
|
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
|
||||||
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
|
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
|
||||||
|
|||||||
@@ -24,10 +24,12 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
CausalLMOutputWithCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
@@ -42,6 +44,7 @@ from ...utils import (
|
|||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
|
get_torch_version,
|
||||||
is_flash_attn_2_available,
|
is_flash_attn_2_available,
|
||||||
is_flash_attn_greater_or_equal_2_10,
|
is_flash_attn_greater_or_equal_2_10,
|
||||||
logging,
|
logging,
|
||||||
@@ -557,6 +560,113 @@ class GPT2FlashAttention2(GPT2Attention):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2SdpaAttention(GPT2Attention):
|
||||||
|
"""
|
||||||
|
GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||||
|
`GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass
|
||||||
|
to adapt to the SDPA API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
|
||||||
|
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
||||||
|
# attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
|
||||||
|
# Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||||
|
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||||
|
if output_attentions or head_mask is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||||
|
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
|
||||||
|
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
||||||
|
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||||
|
)
|
||||||
|
return super().forward(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
layer_past=layer_past,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
# Initial attention projections
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
if is_cross_attention:
|
||||||
|
if not hasattr(self, "q_attn"):
|
||||||
|
raise ValueError(
|
||||||
|
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||||
|
"Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
|
||||||
|
)
|
||||||
|
|
||||||
|
query = self.q_attn(hidden_states)
|
||||||
|
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
else:
|
||||||
|
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
||||||
|
|
||||||
|
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||||
|
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||||
|
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
# Optional kv caching
|
||||||
|
if layer_past is not None:
|
||||||
|
past_key = layer_past[0]
|
||||||
|
past_value = layer_past[1]
|
||||||
|
key = torch.cat((past_key, key), dim=-2)
|
||||||
|
value = torch.cat((past_value, value), dim=-2)
|
||||||
|
|
||||||
|
present = None
|
||||||
|
if use_cache is True:
|
||||||
|
present = (key, value)
|
||||||
|
|
||||||
|
# Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
|
||||||
|
if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
|
||||||
|
query = query.contiguous()
|
||||||
|
key = key.contiguous()
|
||||||
|
value = value.contiguous()
|
||||||
|
|
||||||
|
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||||
|
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||||
|
is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False
|
||||||
|
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
dropout_p=self.attn_dropout.p if self.training else 0.0,
|
||||||
|
is_causal=is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape outputs
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.view(bsz, q_len, self.embed_dim)
|
||||||
|
|
||||||
|
# Final projection
|
||||||
|
attn_output = self.c_proj(attn_output)
|
||||||
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
|
return attn_output, present, None
|
||||||
|
|
||||||
|
|
||||||
class GPT2MLP(nn.Module):
|
class GPT2MLP(nn.Module):
|
||||||
def __init__(self, intermediate_size, config):
|
def __init__(self, intermediate_size, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -574,10 +684,7 @@ class GPT2MLP(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
GPT2_ATTENTION_CLASSES = {
|
GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}
|
||||||
"eager": GPT2Attention,
|
|
||||||
"flash_attention_2": GPT2FlashAttention2,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class GPT2Block(nn.Module):
|
class GPT2Block(nn.Module):
|
||||||
@@ -673,6 +780,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["GPT2Block"]
|
_no_split_modules = ["GPT2Block"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
@@ -1021,11 +1129,24 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||||
position_ids = position_ids.unsqueeze(0)
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
position_embeds = self.wpe(position_ids)
|
||||||
|
hidden_states = inputs_embeds + position_embeds
|
||||||
|
|
||||||
# Attention mask.
|
# Attention mask.
|
||||||
|
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.view(batch_size, -1)
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
if self._attn_implementation == "flash_attention_2":
|
if self._attn_implementation == "flash_attention_2":
|
||||||
attention_mask = attention_mask if 0 in attention_mask else None
|
attention_mask = attention_mask if 0 in attention_mask else None
|
||||||
|
elif _use_sdpa:
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
input_shape=(batch_size, input_shape[-1]),
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_length,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
@@ -1049,7 +1170,11 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
if encoder_attention_mask is None:
|
if encoder_attention_mask is None:
|
||||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
if self._attn_implementation != "flash_attention_2":
|
if _use_sdpa:
|
||||||
|
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||||
|
mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||||
|
)
|
||||||
|
elif not self._attn_implementation == "flash_attention_2":
|
||||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
else:
|
else:
|
||||||
encoder_attention_mask = None
|
encoder_attention_mask = None
|
||||||
@@ -1060,11 +1185,6 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.wte(input_ids)
|
|
||||||
position_embeds = self.wpe(position_ids)
|
|
||||||
hidden_states = inputs_embeds + position_embeds
|
|
||||||
|
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_embeds = self.wte(token_type_ids)
|
token_type_embeds = self.wte(token_type_ids)
|
||||||
hidden_states = hidden_states + token_type_embeds
|
hidden_states = hidden_states + token_type_embeds
|
||||||
|
|||||||
@@ -832,7 +832,8 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
start = datetime.datetime.now()
|
start = datetime.datetime.now()
|
||||||
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
|
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
|
||||||
duration = datetime.datetime.now() - start
|
duration = datetime.datetime.now() - start
|
||||||
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
|
||||||
|
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_contrastive_search_gpt2(self):
|
def test_contrastive_search_gpt2(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user