Sdpa dino v2 (#33403)
* add sdpa to dinov2 * fixup * add dinov2 to sdpa doc * update doc order * [run-slow] dinov2 * common to eager * [run-slow] dinov2 * update attn implementation in common * update test_modeling_dinov2 to have mask_ratio, num_masks and mask_length similar to vit * [run-slow] dinov2 --------- Co-authored-by: Avishai Elmakies <avishai.elma@cs.huji.ac.il>
This commit is contained in:
@@ -217,6 +217,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
|
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
|
||||||
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
|
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
|
||||||
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
|
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
|
||||||
|
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
|
||||||
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
|
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
|
||||||
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
|
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
|
||||||
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
||||||
@@ -275,7 +276,6 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
|
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
|
||||||
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
|
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
|
||||||
|
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
|
FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
|
||||||
|
|||||||
@@ -231,6 +231,38 @@ class Dinov2SelfAttention(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Dinov2
|
||||||
|
class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
|
||||||
|
def __init__(self, config: Dinov2Config) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||||
|
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||||
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
|
|
||||||
|
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
head_mask,
|
||||||
|
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||||
|
is_causal=False,
|
||||||
|
scale=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|
||||||
|
return context_layer, None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
|
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
|
||||||
class Dinov2SelfOutput(nn.Module):
|
class Dinov2SelfOutput(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -290,6 +322,13 @@ class Dinov2Attention(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Dinov2
|
||||||
|
class Dinov2SdpaAttention(Dinov2Attention):
|
||||||
|
def __init__(self, config: Dinov2Config) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.attention = Dinov2SdpaSelfAttention(config)
|
||||||
|
|
||||||
|
|
||||||
class Dinov2LayerScale(nn.Module):
|
class Dinov2LayerScale(nn.Module):
|
||||||
def __init__(self, config) -> None:
|
def __init__(self, config) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -371,6 +410,12 @@ class Dinov2SwiGLUFFN(nn.Module):
|
|||||||
return self.weights_out(hidden)
|
return self.weights_out(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
DINOV2_ATTENTION_CLASSES = {
|
||||||
|
"eager": Dinov2Attention,
|
||||||
|
"sdpa": Dinov2SdpaAttention,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Dinov2Layer(nn.Module):
|
class Dinov2Layer(nn.Module):
|
||||||
"""This corresponds to the Block class in the original implementation."""
|
"""This corresponds to the Block class in the original implementation."""
|
||||||
|
|
||||||
@@ -378,7 +423,7 @@ class Dinov2Layer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.attention = Dinov2Attention(config)
|
self.attention = DINOV2_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||||
self.layer_scale1 = Dinov2LayerScale(config)
|
self.layer_scale1 = Dinov2LayerScale(config)
|
||||||
self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
|
self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
|
||||||
|
|
||||||
@@ -485,6 +530,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
|
|||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["Dinov2SwiGLUFFN"]
|
_no_split_modules = ["Dinov2SwiGLUFFN"]
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -65,6 +65,8 @@ class Dinov2ModelTester:
|
|||||||
type_sequence_label_size=10,
|
type_sequence_label_size=10,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
scope=None,
|
scope=None,
|
||||||
|
attn_implementation="eager",
|
||||||
|
mask_ratio=0.5,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -83,10 +85,14 @@ class Dinov2ModelTester:
|
|||||||
self.type_sequence_label_size = type_sequence_label_size
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
self.attn_implementation = attn_implementation
|
||||||
|
self.mask_ratio = mask_ratio
|
||||||
|
|
||||||
# in Dinov2, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
# in Dinov2, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||||
num_patches = (image_size // patch_size) ** 2
|
num_patches = (image_size // patch_size) ** 2
|
||||||
self.seq_length = num_patches + 1
|
self.seq_length = num_patches + 1
|
||||||
|
self.num_masks = int(self.mask_ratio * self.seq_length)
|
||||||
|
self.mask_length = num_patches
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
@@ -113,6 +119,7 @@ class Dinov2ModelTester:
|
|||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
is_decoder=False,
|
is_decoder=False,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
|
attn_implementation=self.attn_implementation,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values, labels):
|
def create_and_check_model(self, config, pixel_values, labels):
|
||||||
|
|||||||
Reference in New Issue
Block a user