Add sdpa for BioGpt (#33592)
* Add sdpa for BioGpt
* Updates
* Add the docs
* [run_slow] biogpt
* Use the copy mechanism to ensure consistency
* [run_slow] biogpt
@@ -32,6 +32,51 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
|
|||||||
- BioGPT was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next token in a sequence. Leveraging this feature allows BioGPT to generate syntactically coherent text, as can be observed in the run_generation.py example script.
|
- BioGPT was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next token in a sequence. Leveraging this feature allows BioGPT to generate syntactically coherent text, as can be observed in the run_generation.py example script.
|
||||||
- The model can take `past_key_values` (for PyTorch) as input, which are the previously computed key/value attention pairs. Using this (`past_key_values` or `past`) value prevents the model from re-computing pre-computed values during text generation. For PyTorch, see the `past_key_values` argument of the `BioGptForCausalLM.forward()` method for more information on its usage.
|
- The model can take `past_key_values` (for PyTorch) as input, which are the previously computed key/value attention pairs. Using this (`past_key_values` or `past`) value prevents the model from re-computing pre-computed values during text generation. For PyTorch, see the `past_key_values` argument of the `BioGptForCausalLM.forward()` method for more information on its usage.
|
||||||
|
|
||||||
|
### Using Scaled Dot Product Attention (SDPA)
|
||||||
|
|
||||||
|
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||||
|
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||||
|
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||||
|
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||||
|
page for more information.
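To make concrete what the operator computes, here is a minimal pure-Python reference sketch of scaled dot-product attention for a single head, i.e. `softmax(Q K^T / sqrt(d)) V` with an optional causal mask. It is only an illustration under simplifying assumptions (no batching, no dropout, square causal mask); the names `softmax` and `sdpa_reference` are hypothetical helpers and are not part of PyTorch or Transformers, whose fused kernels are far more efficient.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of attention scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa_reference(q: Matrix, k: Matrix, v: Matrix, is_causal: bool = False) -> Matrix:
    """Hypothetical reference: softmax(q @ k^T / sqrt(d)) @ v for one attention head."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, q_i in enumerate(q):
        scores = []
        for j, k_j in enumerate(k):
            if is_causal and j > i:
                # A causal mask hides positions that come after the query position.
                scores.append(float("-inf"))
            else:
                scores.append(scale * sum(a * b for a, b in zip(q_i, k_j)))
        weights = softmax(scores)
        # Each output row is the attention-weighted average of the value rows.
        out.append([sum(w * v_j[c] for w, v_j in zip(weights, v)) for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal_out = sdpa_reference(q, k, v, is_causal=True)
print(causal_out[0])  # the first position can only attend to itself
```

With the causal mask, the first query position sees only the first key, so its output equals the first value row.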
|
||||||
|
|
||||||
|
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||||
|
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||||
|
|
||||||
|
```
|
||||||
|
import torch
from transformers import BioGptForCausalLM
|
||||||
|
model = BioGptForCausalLM.from_pretrained("microsoft/biogpt", attn_implementation="sdpa", torch_dtype=torch.float16)
|
||||||
|
```
|
||||||
|
|
||||||
|
On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) using `float16` and the `microsoft/biogpt` model with a CausalLM head,
|
||||||
|
we saw the following speedups during training.
|
||||||
|
|
||||||
|
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||||
|
|
||||||
|
| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|
||||||
|
|--------------------|------------|---------|---------|----------------------------|---------------------------|-------------|---------------------|--------------------|----------------|
|
||||||
|
| 100 | 1 | 128 | False | 0.038 | 0.031 | 21.301 | 1601.862 | 1601.497 | 0.023 |
|
||||||
|
| 100 | 1 | 256 | False | 0.039 | 0.034 | 15.084 | 1624.944 | 1625.296 | -0.022 |
|
||||||
|
| 100 | 2 | 128 | False | 0.039 | 0.033 | 16.820 | 1624.567 | 1625.296 | -0.045 |
|
||||||
|
| 100 | 2 | 256 | False | 0.065 | 0.059 | 10.255 | 1672.164 | 1672.164 | 0.000 |
|
||||||
|
| 100 | 4 | 128 | False | 0.062 | 0.058 | 6.998 | 1671.435 | 1672.164 | -0.044 |
|
||||||
|
| 100 | 4 | 256 | False | 0.113 | 0.100 | 13.316 | 2350.179 | 1848.435 | 27.144 |
|
||||||
|
| 100 | 8 | 128 | False | 0.107 | 0.098 | 9.883 | 2098.521 | 1848.435 | 13.530 |
|
||||||
|
| 100 | 8 | 256 | False | 0.222 | 0.196 | 13.413 | 3989.980 | 2986.492 | 33.601 |
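The `Mem saving (%)` column is consistent with the relative difference between eager and SDPA peak memory, `(eager / sdpa - 1) * 100`. This formula is inferred from the table values rather than stated by the benchmark, and `mem_saving_percent` is a hypothetical helper name. A small sketch that recomputes three rows of the table:

```python
def mem_saving_percent(eager_mb: float, sdpa_mb: float) -> float:
    """Relative peak-memory saving of SDPA over eager, in percent (inferred formula)."""
    return (eager_mb / sdpa_mb - 1.0) * 100.0


# (batch_size, seq_len, eager peak mem MB, sdpa peak mem MB), copied from the table above
rows = [
    (4, 256, 2350.179, 1848.435),
    (8, 128, 2098.521, 1848.435),
    (8, 256, 3989.980, 2986.492),
]

for batch_size, seq_len, eager_mb, sdpa_mb in rows:
    saving = mem_saving_percent(eager_mb, sdpa_mb)
    print(f"batch_size={batch_size} seq_len={seq_len} saving={saving:.3f}%")
```

The `Speedup (%)` column cannot be reproduced exactly in the same way, because the per-batch times shown in the table are rounded to three decimals.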
|
||||||
|
|
||||||
|
On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) using `float16` and the `microsoft/biogpt` model with a simple AutoModel head,
|
||||||
|
we saw the following speedups during inference.
|
||||||
|
|
||||||
|
| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|
||||||
|
|-------------|------------|---------|---------|---------|----------|------------------------------|-----------------------------|-------------|----------------|--------------|---------------|
|
||||||
|
| 50 | 1 | 64 | True | True | True | 0.115 | 0.098 | 17.392 | 716.998 | 716.998 | 0.000 |
|
||||||
|
| 50 | 1 | 128 | True | True | True | 0.115 | 0.093 | 24.640 | 730.916 | 730.916 | 0.000 |
|
||||||
|
| 50 | 2 | 64 | True | True | True | 0.114 | 0.096 | 19.204 | 730.900 | 730.900 | 0.000 |
|
||||||
|
| 50 | 2 | 128 | True | True | True | 0.117 | 0.095 | 23.529 | 759.262 | 759.262 | 0.000 |
|
||||||
|
| 50 | 4 | 64 | True | True | True | 0.113 | 0.096 | 18.325 | 759.229 | 759.229 | 0.000 |
|
||||||
|
| 50 | 4 | 128 | True | True | True | 0.186 | 0.178 | 4.289 | 816.478 | 816.478 | 0.000 |
|
||||||
|
|
||||||
|
|
||||||
## Resources
|
## Resources
|
||||||
|
|
||||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||||
|
|||||||
@@ -208,6 +208,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
|
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
|
||||||
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
|
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
|
||||||
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
|
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
|
||||||
|
* [BioGpt](https://huggingface.co/docs/transformers/model_doc/biogpt#transformers.BioGptModel)
|
||||||
* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel)
|
* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel)
|
||||||
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
|
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
|
||||||
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
|
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
CausalLMOutputWithCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
@@ -244,16 +244,130 @@ class BioGptAttention(nn.Module):
|
|||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->BioGpt
|
||||||
|
class BioGptSdpaAttention(BioGptAttention):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
if output_attentions or layer_head_mask is not None:
|
||||||
|
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||||
|
logger.warning_once(
|
||||||
|
"BioGptModel is using BioGptSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||||
|
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||||
|
)
|
||||||
|
return super().forward(
|
||||||
|
hidden_states,
|
||||||
|
key_value_states=key_value_states,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
|
# for the decoder
|
||||||
|
is_cross_attention = key_value_states is not None
|
||||||
|
|
||||||
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
# get query proj
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
# get key, value proj
|
||||||
|
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||||
|
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||||
|
# the provided `key_value_states` to support prefix tuning
|
||||||
|
if (
|
||||||
|
is_cross_attention
|
||||||
|
and past_key_value is not None
|
||||||
|
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||||
|
):
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_states = past_key_value[0]
|
||||||
|
value_states = past_key_value[1]
|
||||||
|
elif is_cross_attention:
|
||||||
|
# cross_attentions
|
||||||
|
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||||
|
elif past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
else:
|
||||||
|
# self_attention
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
|
query_states = self._shape(query_states, tgt_len, bsz)
|
||||||
|
|
||||||
|
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||||
|
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||||
|
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||||
|
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
|
||||||
|
|
||||||
|
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||||
|
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
dropout_p=self.dropout if self.training else 0.0,
|
||||||
|
is_causal=is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
|
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||||
|
# partitioned across GPUs when using tensor-parallelism.
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
BIOGPT_ATTENTION_CLASSES = {
|
||||||
|
"eager": BioGptAttention,
|
||||||
|
"sdpa": BioGptSdpaAttention,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class BioGptDecoderLayer(nn.Module):
|
class BioGptDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: BioGptConfig):
|
def __init__(self, config: BioGptConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = BioGptAttention(
|
self.self_attn = BIOGPT_ATTENTION_CLASSES[config._attn_implementation](
|
||||||
embed_dim=self.embed_dim,
|
embed_dim=self.embed_dim,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
dropout=config.attention_probs_dropout_prob,
|
dropout=config.attention_probs_dropout_prob,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
|
is_causal=True,
|
||||||
)
|
)
|
||||||
self.dropout = config.hidden_dropout_prob
|
self.dropout = config.hidden_dropout_prob
|
||||||
self.activation_fn = ACT2FN[config.hidden_act]
|
self.activation_fn = ACT2FN[config.hidden_act]
|
||||||
@@ -337,6 +451,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = BioGptConfig
|
config_class = BioGptConfig
|
||||||
base_model_prefix = "biogpt"
|
base_model_prefix = "biogpt"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
@@ -444,6 +559,7 @@ class BioGptModel(BioGptPreTrainedModel):
|
|||||||
self.layer_norm = nn.LayerNorm(self.embed_dim)
|
self.layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@@ -511,9 +627,16 @@ class BioGptModel(BioGptPreTrainedModel):
|
|||||||
# embed positions
|
# embed positions
|
||||||
positions = self.embed_positions(attention_mask, past_key_values_length)
|
positions = self.embed_positions(attention_mask, past_key_values_length)
|
||||||
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
if self._use_sdpa and not output_attentions and head_mask is None:
|
||||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
||||||
)
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = inputs_embeds + positions
|
hidden_states = inputs_embeds + positions
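The selection logic introduced by this change can be summarized in a small, dependency-free sketch. It only mirrors the three decisions visible in the diff (class lookup by `config._attn_implementation`, choice of mask builder in `BioGptModel.forward`, and the `is_causal` condition in `BioGptSdpaAttention.forward`). The stand-in classes and the helper names `pick_mask_builder` and `uses_causal_fast_path` are hypothetical and do not exist in Transformers.

```python
from typing import Optional


# Hypothetical stand-ins for BioGptAttention and BioGptSdpaAttention.
class EagerAttention:
    name = "eager"


class SdpaAttention(EagerAttention):
    name = "sdpa"


# Mirrors BIOGPT_ATTENTION_CLASSES: the implementation is looked up by its config key.
ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}


def pick_mask_builder(use_sdpa: bool, output_attentions: bool, head_mask: Optional[object]) -> str:
    """Return the name of the mask helper the model would call for these arguments."""
    if use_sdpa and not output_attentions and head_mask is None:
        return "_prepare_4d_causal_attention_mask_for_sdpa"
    # output_attentions=True or a head mask forces the manual attention path.
    return "_prepare_4d_causal_attention_mask"


def uses_causal_fast_path(layer_is_causal: bool, attention_mask: Optional[object], tgt_len: int) -> bool:
    """Mirror the `is_causal` flag passed to scaled_dot_product_attention."""
    return layer_is_causal and attention_mask is None and tgt_len > 1


attn_cls = ATTENTION_CLASSES["sdpa"]
print(attn_cls.name, pick_mask_builder(True, False, None), uses_causal_fast_path(True, None, 8))
```

Keeping the fallback condition identical in both places matters: the SDPA-specific mask is only built when the SDPA kernel will actually run, and the manual path always receives a full 4D causal mask.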
|
||||||
|
|
||||||
|
|||||||