Add sdpa for Vivit (#33757)
* chore:add sdpa to vivit * fix:failing slow test_inference_interpolate_pos_encoding(failing on main branch too) * chore:fix nits * ci:fix repo consistency failure * chore:add info and benchmark to model doc * [run_slow] vivit * chore:revert interpolation test fix for new issue * [run_slow] vivit * [run_slow] vivit * [run_slow] vivit * chore:add fallback for output_attentions being True * [run_slow] vivit * style:make fixup * [run_slow] vivit
This commit is contained in:
@@ -23,6 +23,43 @@ The abstract from the paper is the following:
|
|||||||
|
|
||||||
This model was contributed by [jegormeister](https://huggingface.co/jegormeister). The original code (written in JAX) can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit).
|
This model was contributed by [jegormeister](https://huggingface.co/jegormeister). The original code (written in JAX) can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit).
|
||||||
|
|
||||||
|
### Using Scaled Dot Product Attention (SDPA)
|
||||||
|
|
||||||
|
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||||
|
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||||
|
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||||
|
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||||
|
page for more information.
|
||||||
|
|
||||||
|
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||||
|
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||||
|
|
||||||
|
```
|
||||||
|
from transformers import VivitModel
|
||||||
|
model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400", attn_implementation="sdpa", torch_dtype=torch.float16)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||||
|
|
||||||
|
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vivit-b-16x2-kinetics400` model, we saw the following speedups during inference.
|
||||||
|
|
||||||
|
### Training
|
||||||
|
| num_training_steps | batch_size | is cuda | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|
||||||
|
|---------------------:|-------------:|----------:|--------------:|----------------------:|---------------------:|-----------------:|
|
||||||
|
| 100 | 1 | True | 7.122 | 2575.28 | 5932.54 | 130.364 |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Inference
|
||||||
|
| num_batches | batch_size | is cuda | is half | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|
||||||
|
|---------------|--------------|-----------|-----------|---------------|------------------|---------------|-----------------|
|
||||||
|
| 20 | 1 | True | False | 15.422 | 715.807 | 317.079 | 125.75 |
|
||||||
|
| 20 | 2 | True | False | 17.146 | 1234.75 | 447.175 | 176.122 |
|
||||||
|
| 20 | 4 | True | False | 18.093 | 2275.82 | 709.864 | 220.6 |
|
||||||
|
| 20 | 8 | True | False | 19.284 | 4358.19 | 1233.24 | 253.393 |
|
||||||
|
|
||||||
|
|
||||||
## VivitConfig
|
## VivitConfig
|
||||||
|
|
||||||
[[autodoc]] VivitConfig
|
[[autodoc]] VivitConfig
|
||||||
|
|||||||
@@ -278,6 +278,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
|
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
|
||||||
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
|
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
|
||||||
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell)
|
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell)
|
||||||
|
* [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel)
|
||||||
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
|
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
|
||||||
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
||||||
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
|
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
|
||||||
|
|||||||
@@ -227,6 +227,51 @@ class VivitSelfAttention(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit
|
||||||
|
class VivitSdpaSelfAttention(VivitSelfAttention):
|
||||||
|
def __init__(self, config: VivitConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||||
|
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||||
|
if output_attentions or head_mask is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"VivitSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
|
||||||
|
" `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
|
||||||
|
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
|
||||||
|
' removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||||
|
)
|
||||||
|
return super().forward(
|
||||||
|
hidden_states,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
|
|
||||||
|
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
head_mask,
|
||||||
|
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||||
|
is_causal=False,
|
||||||
|
scale=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
|
return context_layer, None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
|
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
|
||||||
class VivitSelfOutput(nn.Module):
|
class VivitSelfOutput(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -286,6 +331,13 @@ class VivitAttention(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit
|
||||||
|
class VivitSdpaAttention(VivitAttention):
|
||||||
|
def __init__(self, config: VivitConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.attention = VivitSdpaSelfAttention(config)
|
||||||
|
|
||||||
|
|
||||||
class VivitIntermediate(nn.Module):
|
class VivitIntermediate(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -320,6 +372,12 @@ class VivitOutput(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
VIVIT_ATTENTION_CLASSES = {
|
||||||
|
"eager": VivitAttention,
|
||||||
|
"sdpa": VivitSdpaAttention,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class VivitLayer(nn.Module):
|
class VivitLayer(nn.Module):
|
||||||
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
|
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
|
||||||
|
|
||||||
@@ -327,7 +385,7 @@ class VivitLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
self.seq_len_dim = 1
|
self.seq_len_dim = 1
|
||||||
self.attention = VivitAttention(config)
|
self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||||
self.intermediate = VivitIntermediate(config)
|
self.intermediate = VivitIntermediate(config)
|
||||||
self.output = VivitOutput(config)
|
self.output = VivitOutput(config)
|
||||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
@@ -436,6 +494,7 @@ class VivitPreTrainedModel(PreTrainedModel):
|
|||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = []
|
_no_split_modules = []
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -65,6 +65,8 @@ class VivitModelTester:
|
|||||||
layer_norm_eps=1e-06,
|
layer_norm_eps=1e-06,
|
||||||
qkv_bias=True,
|
qkv_bias=True,
|
||||||
scope=None,
|
scope=None,
|
||||||
|
attn_implementation="eager",
|
||||||
|
mask_ratio=0.5,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -86,12 +88,15 @@ class VivitModelTester:
|
|||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
self.qkv_bias = qkv_bias
|
self.qkv_bias = qkv_bias
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
self.attn_implementation = attn_implementation
|
||||||
|
|
||||||
self.seq_length = (
|
self.seq_length = (
|
||||||
(self.image_size // self.tubelet_size[2])
|
(self.image_size // self.tubelet_size[2])
|
||||||
* (self.image_size // self.tubelet_size[1])
|
* (self.image_size // self.tubelet_size[1])
|
||||||
* (self.num_frames // self.tubelet_size[0])
|
* (self.num_frames // self.tubelet_size[0])
|
||||||
) + 1 # CLS token
|
) + 1 # CLS token
|
||||||
|
self.mask_ratio = mask_ratio
|
||||||
|
self.num_masks = int(mask_ratio * self.seq_length)
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pixel_values = floats_tensor(
|
pixel_values = floats_tensor(
|
||||||
@@ -122,6 +127,7 @@ class VivitModelTester:
|
|||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
layer_norm_eps=self.layer_norm_eps,
|
layer_norm_eps=self.layer_norm_eps,
|
||||||
qkv_bias=self.qkv_bias,
|
qkv_bias=self.qkv_bias,
|
||||||
|
attn_implementation=self.attn_implementation,
|
||||||
)
|
)
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
return config
|
return config
|
||||||
|
|||||||
Reference in New Issue
Block a user