diff --git a/docs/source/en/model_doc/gpt2.md b/docs/source/en/model_doc/gpt2.md index b2afbbd3b2..89a0429cca 100644 --- a/docs/source/en/model_doc/gpt2.md +++ b/docs/source/en/model_doc/gpt2.md @@ -127,6 +127,64 @@ Below is an expected speedup diagram that compares pure inference time between t + +## Using Scaled Dot Product Attention (SDPA) +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +```python +from transformers import AutoModelForCausalLM +model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="sdpa") +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (rtx3080ti-16GB, PyTorch 2.2.1, OS Ubuntu 22.04) using `float16` with +[gpt2-large](https://huggingface.co/openai-community/gpt2-large), we saw the +following speedups during training and inference. + +### Training +| Batch size | Seq len | Time per batch (Eager - s) | Time per batch (SDPA - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) | +|-----------:|--------:|----------------------------:|--------------------------:|------------:|--------------------:|-------------------:|------------------:| +| 1 | 128 | 0.039 | 0.032 | 23.042 | 3482.32 | 3494.62 | -0.352 | +| 1 | 256 | 0.073 | 0.059 | 25.15 | 3546.66 | 3552.6 | -0.167 | +| 1 | 512 | 0.155 | 0.118 | 30.96 | 4230.1 | 3665.59 | 15.4 | +| 1 | 1024 | 0.316 | 0.209 | 50.839 | 8682.26 | 4881.09 | 77.875 | +| 2 | 128 | 0.07 | 0.06 | 15.324 | 3557.8 | 3545.91 | 0.335 | +| 2 | 256 | 0.143 | 0.122 | 16.53 | 3901.5 | 3657.68 | 6.666 | +| 2 | 512 | 0.267 | 0.213 | 25.626 | 7062.21 | 4876.47 | 44.822 | +| 2 | 1024 | OOM | 0.404 | / | OOM | 8096.35 | SDPA does not OOM | +| 4 | 128 | 0.134 | 0.128 | 4.412 | 3675.79 | 3648.72 | 0.742 | +| 4 | 256 | 0.243 | 0.217 | 12.292 | 6129.76 | 4871.12 | 25.839 | +| 4 | 512 | 0.494 | 0.406 | 21.687 | 12466.6 | 8102.64 | 53.858 | +| 4 | 1024 | OOM | 0.795 | / | OOM | 14568.2 | SDPA does not OOM | + +### Inference +| Batch size | Seq len | Per token latency Eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem Eager (MB) | Mem SDPA (MB) | Mem saved (%) | +|-----------:|--------:|-----------------------------:|----------------------------:|------------:|---------------:|--------------:|--------------:| +| 1 | 128 | 7.991 | 6.968 | 14.681 | 1685.2 | 1701.32 | -0.947 | +| 1 | 256 | 8.462 | 7.199 | 17.536 | 1745.49 | 1770.78 | -1.428 | +| 1 | 512 | 8.68 | 7.853 | 10.529 | 1907.69 | 1921.29 | -0.708 | +| 1 | 768 | 9.101 | 8.365 | 8.791 | 2032.93 | 2068.12 | -1.701 | +| 2 | 128 | 9.169 | 9.001 | 1.861 | 1803.84 | 1811.4 | -0.418 | +| 2 | 256 | 9.907 | 9.78 | 1.294 | 1907.72 | 1921.44 | -0.714 | +| 2 | 512 | 11.519 | 11.644 | -1.071 | 2176.86 | 2197.75 | -0.951 | +| 2 | 768 | 13.022 | 13.407 | -2.873 | 2464.3 | 2491.06 | -1.074 | +| 4 | 128 | 10.097 | 9.831 | 2.709 | 1942.25 | 1985.13 | -2.16 | +| 4 | 256 | 11.599 | 11.398 | 1.764 | 2177.28 | 2197.86 | -0.937 | +| 4 | 512 | 14.653 | 14.45 | 1.411 | 2753.16 | 2772.57 | -0.7 | +| 4 | 768 | 17.846 | 17.617 | 1.299 | 3327.04 | 3343.97 | -0.506 | + + + + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index a81fc49381..fb443a3ce1 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -201,6 +201,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) +* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 7c62d06949..ed9b10bbdb 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -24,10 +24,12 @@ from typing import Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -42,6 +44,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, @@ -557,6 +560,113 @@ class GPT2FlashAttention2(GPT2Attention): ) +class GPT2SdpaAttention(GPT2Attention): + """ + GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass + to adapt to the SDPA API. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__ + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + bsz, q_len, _ = hidden_states.size() + + # Initial attention projections + is_cross_attention = encoder_hidden_states is not None + if is_cross_attention: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + # Optional kv caching + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA + if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=self.attn_dropout.p if self.training else 0.0, + is_causal=is_causal, + ) + + # Reshape outputs + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.embed_dim) + + # Final projection + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + return attn_output, present, None + + class GPT2MLP(nn.Module): def __init__(self, intermediate_size, config): super().__init__() @@ -574,10 +684,7 @@ class GPT2MLP(nn.Module): return hidden_states -GPT2_ATTENTION_CLASSES = { - "eager": GPT2Attention, - "flash_attention_2": GPT2FlashAttention2, -} +GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention} class GPT2Block(nn.Module): @@ -673,6 +780,7 @@ class GPT2PreTrainedModel(PreTrainedModel): _no_split_modules = ["GPT2Block"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -1021,11 +1129,24 @@ class GPT2Model(GPT2PreTrainedModel): position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + # Attention mask. + _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None if attention_mask is not None: attention_mask = attention_mask.view(batch_size, -1) if self._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None + elif _use_sdpa: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(batch_size, input_shape[-1]), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) else: # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] @@ -1049,7 +1170,11 @@ class GPT2Model(GPT2PreTrainedModel): encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - if self._attn_implementation != "flash_attention_2": + if _use_sdpa: + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + elif not self._attn_implementation == "flash_attention_2": encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_attention_mask = None @@ -1060,11 +1185,6 @@ class GPT2Model(GPT2PreTrainedModel): # head_mask has shape n_layer x batch x n_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index cde28cbc58..5755658288 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -832,7 +832,8 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): start = datetime.datetime.now() model.generate(input_ids, do_sample=False, max_time=None, max_length=256) duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) + self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) @slow def test_contrastive_search_gpt2(self):