Add sdpa for Vivit (#33757)

* chore:add sdpa to vivit

* fix:failing slow test_inference_interpolate_pos_encoding(failing on main branch too)

* chore:fix nits

* ci:fix repo consistency failure

* chore:add info and benchmark to model doc

* [run_slow] vivit

* chore:revert interpolation test fix for new issue

* [run_slow] vivit

* [run_slow] vivit

* [run_slow] vivit

* chore:add fallback for output_attentions being True

* [run_slow] vivit

* style:make fixup

* [run_slow] vivit
This commit is contained in:
Prakarsh Kaushik
2024-10-15 14:57:54 +05:30
committed by GitHub
parent 23874f5948
commit 293e6271c6
4 changed files with 104 additions and 1 deletions

View File

@@ -23,6 +23,43 @@ The abstract from the paper is the following:
This model was contributed by [jegormeister](https://huggingface.co/jegormeister). The original code (written in JAX) can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit). This model was contributed by [jegormeister](https://huggingface.co/jegormeister). The original code (written in JAX) can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit).
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import VivitModel
model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vivit-b-16x2-kinetics400` model, we saw the following speedups during inference.
### Training
| num_training_steps | batch_size | is cuda | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---------------------:|-------------:|----------:|--------------:|----------------------:|---------------------:|-----------------:|
| 100 | 1 | True | 7.122 | 2575.28 | 5932.54 | 130.364 |
### Inference
| num_batches | batch_size | is cuda | is half | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---------------|--------------|-----------|-----------|---------------|------------------|---------------|-----------------|
| 20 | 1 | True | False | 15.422 | 715.807 | 317.079 | 125.75 |
| 20 | 2 | True | False | 17.146 | 1234.75 | 447.175 | 176.122 |
| 20 | 4 | True | False | 18.093 | 2275.82 | 709.864 | 220.6 |
| 20 | 8 | True | False | 19.284 | 4358.19 | 1233.24 | 253.393 |
## VivitConfig ## VivitConfig
[[autodoc]] VivitConfig [[autodoc]] VivitConfig

View File

@@ -278,6 +278,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel) * [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel) * [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell) * [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell)
* [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model) * [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel) * [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)

View File

@@ -227,6 +227,51 @@ class VivitSelfAttention(nn.Module):
return outputs return outputs
# Adapted from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit
class VivitSdpaSelfAttention(VivitSelfAttention):
def __init__(self, config: VivitConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"VivitSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
" `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
' removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
head_mask,
output_attentions,
)
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
class VivitSelfOutput(nn.Module): class VivitSelfOutput(nn.Module):
""" """
@@ -286,6 +331,13 @@ class VivitAttention(nn.Module):
return outputs return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit
class VivitSdpaAttention(VivitAttention):
def __init__(self, config: VivitConfig) -> None:
super().__init__(config)
self.attention = VivitSdpaSelfAttention(config)
class VivitIntermediate(nn.Module): class VivitIntermediate(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -320,6 +372,12 @@ class VivitOutput(nn.Module):
return hidden_states return hidden_states
VIVIT_ATTENTION_CLASSES = {
"eager": VivitAttention,
"sdpa": VivitSdpaAttention,
}
class VivitLayer(nn.Module): class VivitLayer(nn.Module):
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation.""" """This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
@@ -327,7 +385,7 @@ class VivitLayer(nn.Module):
super().__init__() super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1 self.seq_len_dim = 1
self.attention = VivitAttention(config) self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = VivitIntermediate(config) self.intermediate = VivitIntermediate(config)
self.output = VivitOutput(config) self.output = VivitOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -436,6 +494,7 @@ class VivitPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = [] _no_split_modules = []
_supports_sdpa = True
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -65,6 +65,8 @@ class VivitModelTester:
layer_norm_eps=1e-06, layer_norm_eps=1e-06,
qkv_bias=True, qkv_bias=True,
scope=None, scope=None,
attn_implementation="eager",
mask_ratio=0.5,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
@@ -86,12 +88,15 @@ class VivitModelTester:
self.layer_norm_eps = layer_norm_eps self.layer_norm_eps = layer_norm_eps
self.qkv_bias = qkv_bias self.qkv_bias = qkv_bias
self.scope = scope self.scope = scope
self.attn_implementation = attn_implementation
self.seq_length = ( self.seq_length = (
(self.image_size // self.tubelet_size[2]) (self.image_size // self.tubelet_size[2])
* (self.image_size // self.tubelet_size[1]) * (self.image_size // self.tubelet_size[1])
* (self.num_frames // self.tubelet_size[0]) * (self.num_frames // self.tubelet_size[0])
) + 1 # CLS token ) + 1 # CLS token
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor( pixel_values = floats_tensor(
@@ -122,6 +127,7 @@ class VivitModelTester:
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
layer_norm_eps=self.layer_norm_eps, layer_norm_eps=self.layer_norm_eps,
qkv_bias=self.qkv_bias, qkv_bias=self.qkv_bias,
attn_implementation=self.attn_implementation,
) )
config.num_labels = self.num_labels config.num_labels = self.num_labels
return config return config