Add sdpa for Vivit (#33757)
* chore:add sdpa to vivit * fix:failing slow test_inference_interpolate_pos_encoding(failing on main branch too) * chore:fix nits * ci:fix repo consistency failure * chore:add info and benchmark to model doc * [run_slow] vivit * chore:revert interpolation test fix for new issue * [run_slow] vivit * [run_slow] vivit * [run_slow] vivit * chore:add fallback for output_attentions being True * [run_slow] vivit * style:make fixup * [run_slow] vivit
This commit is contained in:
@@ -23,6 +23,43 @@ The abstract from the paper is the following:
|
||||
|
||||
This model was contributed by [jegormeister](https://huggingface.co/jegormeister). The original code (written in JAX) can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit).
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```
|
||||
from transformers import VivitModel
|
||||
model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400", attn_implementation="sdpa", torch_dtype=torch.float16)
|
||||
...
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vivit-b-16x2-kinetics400` model, we saw the following speedups during inference.
|
||||
|
||||
### Training
|
||||
| num_training_steps | batch_size | is cuda | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|
||||
|---------------------:|-------------:|----------:|--------------:|----------------------:|---------------------:|-----------------:|
|
||||
| 100 | 1 | True | 7.122 | 2575.28 | 5932.54 | 130.364 |
|
||||
|
||||
|
||||
|
||||
### Inference
|
||||
| num_batches | batch_size | is cuda | is half | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|
||||
|---------------|--------------|-----------|-----------|---------------|------------------|---------------|-----------------|
|
||||
| 20 | 1 | True | False | 15.422 | 715.807 | 317.079 | 125.75 |
|
||||
| 20 | 2 | True | False | 17.146 | 1234.75 | 447.175 | 176.122 |
|
||||
| 20 | 4 | True | False | 18.093 | 2275.82 | 709.864 | 220.6 |
|
||||
| 20 | 8 | True | False | 19.284 | 4358.19 | 1233.24 | 253.393 |
|
||||
|
||||
|
||||
## VivitConfig
|
||||
|
||||
[[autodoc]] VivitConfig
|
||||
|
||||
@@ -278,6 +278,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
|
||||
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
|
||||
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell)
|
||||
* [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel)
|
||||
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
|
||||
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
||||
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
|
||||
|
||||
@@ -227,6 +227,51 @@ class VivitSelfAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Adapted from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit
|
||||
class VivitSdpaSelfAttention(VivitSelfAttention):
|
||||
def __init__(self, config: VivitConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions or head_mask is not None:
|
||||
logger.warning_once(
|
||||
"VivitSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
|
||||
" `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
|
||||
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
|
||||
' removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
|
||||
class VivitSelfOutput(nn.Module):
|
||||
"""
|
||||
@@ -286,6 +331,13 @@ class VivitAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit
|
||||
class VivitSdpaAttention(VivitAttention):
|
||||
def __init__(self, config: VivitConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = VivitSdpaSelfAttention(config)
|
||||
|
||||
|
||||
class VivitIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -320,6 +372,12 @@ class VivitOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
VIVIT_ATTENTION_CLASSES = {
|
||||
"eager": VivitAttention,
|
||||
"sdpa": VivitSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class VivitLayer(nn.Module):
|
||||
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
|
||||
|
||||
@@ -327,7 +385,7 @@ class VivitLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = VivitAttention(config)
|
||||
self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.intermediate = VivitIntermediate(config)
|
||||
self.output = VivitOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -436,6 +494,7 @@ class VivitPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = []
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -65,6 +65,8 @@ class VivitModelTester:
|
||||
layer_norm_eps=1e-06,
|
||||
qkv_bias=True,
|
||||
scope=None,
|
||||
attn_implementation="eager",
|
||||
mask_ratio=0.5,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -86,12 +88,15 @@ class VivitModelTester:
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.qkv_bias = qkv_bias
|
||||
self.scope = scope
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
self.seq_length = (
|
||||
(self.image_size // self.tubelet_size[2])
|
||||
* (self.image_size // self.tubelet_size[1])
|
||||
* (self.num_frames // self.tubelet_size[0])
|
||||
) + 1 # CLS token
|
||||
self.mask_ratio = mask_ratio
|
||||
self.num_masks = int(mask_ratio * self.seq_length)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor(
|
||||
@@ -122,6 +127,7 @@ class VivitModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
qkv_bias=self.qkv_bias,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
config.num_labels = self.num_labels
|
||||
return config
|
||||
|
||||
Reference in New Issue
Block a user