🔴 Update CLIP vision attention to new attention interface (#37498)
* update attention interface * fix test * propagate attention changes * revert weird changes * fix modular * what? * ruff is mocking me * ruff being ruff * simplify test suite + fix FA2 * fixup tests + propagate FA2 fixes * add Copied From where relevant * fix conflict between copies and modular * recover FA2 training for CLIP + handle quantization * don't ditch the warning * tiny import fix * code review (FA2 support, copied from) * fix style * modularity * wrong copies * future-proofing for TP * mlcd inherits from CLIP
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -30,9 +30,15 @@ from ...modeling_outputs import (
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndProjection,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, torch_int
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
|
||||
|
||||
|
||||
@@ -721,7 +727,29 @@ class AltRobertaPooler(nn.Module):
|
||||
return pooled_output
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->AltCLIP
|
||||
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class AltCLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@@ -738,15 +766,13 @@ class AltCLIPAttention(nn.Module):
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -756,74 +782,51 @@ class AltCLIPAttention(nn.Module):
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
|
||||
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
|
||||
if self.config._attn_implementation != "flash_attention_2":
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attention_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attention_mask = causal_attention_mask
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
self.is_causal = causal_attention_mask is not None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->AltCLIP
|
||||
|
||||
@@ -15,19 +15,16 @@
|
||||
"""PyTorch CLIP model."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
@@ -41,10 +38,6 @@ from ...utils import (
|
||||
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
@@ -297,10 +290,34 @@ class CLIPTextEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
output_attentions: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class CLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@@ -313,15 +330,13 @@ class CLIPAttention(nn.Module):
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -331,242 +346,55 @@ class CLIPAttention(nn.Module):
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
|
||||
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
|
||||
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
self.is_causal = causal_attention_mask is not None
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
|
||||
|
||||
class CLIPFlashAttention2(CLIPAttention):
|
||||
"""
|
||||
CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
output_attentions = False
|
||||
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
|
||||
dropout_rate = self.dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32.
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attention_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attention_mask = causal_attention_mask
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
is_causal=causal_attention_mask is not None,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class CLIPSdpaAttention(CLIPAttention):
|
||||
"""
|
||||
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`CLIPAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from CLIPAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
|
||||
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
|
||||
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
|
||||
'be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attn_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attn_mask = causal_attention_mask
|
||||
else:
|
||||
attn_mask = attention_mask
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially.
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
scale=self.scale,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
CLIP_ATTENTION_CLASSES = {
|
||||
"eager": CLIPAttention,
|
||||
"sdpa": CLIPSdpaAttention,
|
||||
"flash_attention_2": CLIPFlashAttention2,
|
||||
}
|
||||
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -583,10 +411,10 @@ class CLIPMLP(nn.Module):
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
def __init__(self, config: CLIPConfig):
|
||||
def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.self_attn = CLIPAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = CLIPMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
@@ -952,7 +780,7 @@ class CLIPTextTransformer(nn.Module):
|
||||
|
||||
# expand attention_mask
|
||||
if attention_mask is not None and not self._use_flash_attention_2:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
|
||||
|
||||
encoder_outputs: BaseModelOutput = self.encoder(
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import copy
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -26,7 +26,7 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
@@ -264,11 +264,33 @@ class CLIPSegTextEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->CLIPSeg
|
||||
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class CLIPSegAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: Union[CLIPSegVisionConfig, CLIPSegTextConfig]):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@@ -281,15 +303,13 @@ class CLIPSegAttention(nn.Module):
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -299,74 +319,52 @@ class CLIPSegAttention(nn.Module):
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
|
||||
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
|
||||
if self.config._attn_implementation != "flash_attention_2":
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attention_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attention_mask = causal_attention_mask
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
self.is_causal = causal_attention_mask is not None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->CLIPSeg
|
||||
|
||||
@@ -303,7 +303,6 @@ class ClvpSelfAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention._shape
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -25,7 +25,6 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...file_utils import ModelOutput
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
@@ -34,9 +33,10 @@ from ...modeling_outputs import (
|
||||
BaseModelOutputWithPooling,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
@@ -698,7 +698,6 @@ class GitVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP
|
||||
class GitVisionMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -714,7 +713,29 @@ class GitVisionMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->GitVision
|
||||
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class GitVisionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@@ -731,15 +752,13 @@ class GitVisionAttention(nn.Module):
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -749,74 +768,51 @@ class GitVisionAttention(nn.Module):
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
|
||||
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
|
||||
if self.config._attn_implementation != "flash_attention_2":
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attention_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attention_mask = causal_attention_mask
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
self.is_causal = causal_attention_mask is not None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -24,7 +24,11 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...utils import ModelOutput, logging
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
logging,
|
||||
)
|
||||
from .configuration_idefics import IdeficsVisionConfig
|
||||
|
||||
|
||||
@@ -160,11 +164,33 @@ class IdeficsVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision
|
||||
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class IdeficsVisionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: IdeficsVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@@ -177,15 +203,13 @@ class IdeficsVisionAttention(nn.Module):
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -195,74 +219,51 @@ class IdeficsVisionAttention(nn.Module):
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
|
||||
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
|
||||
if self.config._attn_implementation != "flash_attention_2":
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attention_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attention_mask = causal_attention_mask
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
self.is_causal = causal_attention_mask is not None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -31,7 +31,7 @@ from ...modeling_outputs import (
|
||||
BaseModelOutputWithPooling,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
@@ -466,7 +466,29 @@ class Kosmos2VisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Kosmos2Vision
|
||||
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Kosmos2VisionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@@ -483,15 +505,13 @@ class Kosmos2VisionAttention(nn.Module):
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -501,74 +521,51 @@ class Kosmos2VisionAttention(nn.Module):
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
|
||||
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
|
||||
if self.config._attn_implementation != "flash_attention_2":
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attention_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attention_mask = causal_attention_mask
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
self.is_causal = causal_attention_mask is not None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision
|
||||
|
||||
@@ -172,25 +172,6 @@ class MLCDVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
@@ -217,6 +198,25 @@ def eager_attention_forward(
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
|
||||
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -253,16 +253,13 @@ class MLCDAttention(nn.Module):
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.num_key_value_groups = config.num_key_value_groups
|
||||
self.is_causal = False
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch X-CLIP model."""
|
||||
|
||||
from copy import copy
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -25,7 +25,7 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
@@ -223,7 +223,29 @@ class XCLIPTextEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->XCLIP
|
||||
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class XCLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@@ -240,15 +262,13 @@ class XCLIPAttention(nn.Module):
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -258,77 +278,54 @@ class XCLIPAttention(nn.Module):
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
|
||||
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
|
||||
if self.config._attn_implementation != "flash_attention_2":
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attention_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attention_mask = causal_attention_mask
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
self.is_causal = causal_attention_mask is not None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->XCLIP
|
||||
class XCLIPMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -1330,8 +1327,7 @@ class XCLIPModel(XCLIPPreTrainedModel):
|
||||
|
||||
self.prompts_visual_layernorm = nn.LayerNorm(self.vision_embed_dim, eps=config.vision_config.layer_norm_eps)
|
||||
self.prompts_visual_projection = nn.Parameter(torch.randn(self.vision_embed_dim, self.projection_dim))
|
||||
|
||||
mit_config = copy(vision_config)
|
||||
mit_config = copy.copy(vision_config)
|
||||
mit_config.hidden_size = vision_config.mit_hidden_size
|
||||
mit_config.intermediate_size = vision_config.mit_intermediate_size
|
||||
mit_config.num_hidden_layers = vision_config.mit_num_hidden_layers
|
||||
|
||||
@@ -17,7 +17,6 @@ import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -36,14 +35,12 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.utils import (
|
||||
is_torch_available,
|
||||
is_torch_bf16_available_on_device,
|
||||
is_torch_fp16_available_on_device,
|
||||
is_torch_sdpa_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||
ModelTesterMixin,
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
@@ -67,11 +64,6 @@ if is_torch_available():
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_sdpa_available():
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
@@ -170,6 +162,11 @@ class CLIPVisionModelTester:
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
def test_eager_matches_sdpa_inference(self, *args):
|
||||
return getattr(ModelTesterMixin, self._testMethodName)(self)
|
||||
|
||||
|
||||
class CLIPModelTesterMixin(ModelTesterMixin):
|
||||
"""
|
||||
@@ -178,6 +175,7 @@ class CLIPModelTesterMixin(ModelTesterMixin):
|
||||
different output logits, and are not supposed to be used or tested with padding_side="left".
|
||||
"""
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -186,8 +184,8 @@ class CLIPModelTesterMixin(ModelTesterMixin):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# Load the model with SDPA
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
# Load the model with SDPA (it is the default, but we explicit it for clarity)
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
# Load model with eager attention
|
||||
@@ -197,180 +195,17 @@ class CLIPModelTesterMixin(ModelTesterMixin):
|
||||
)
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
|
||||
# SigLip has one shared cls attr for all models, so we assign both submodels heer
|
||||
vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager"
|
||||
|
||||
# `None` as it is the requested one which will be assigned to each sub-config
|
||||
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
|
||||
if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"):
|
||||
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
|
||||
self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn)
|
||||
if hasattr(model_sdpa, "vision_model"):
|
||||
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
|
||||
|
||||
if hasattr(model_sdpa, "text_model"):
|
||||
self.assertTrue(model_sdpa.text_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_eager.text_model.config._attn_implementation == "eager")
|
||||
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and model_sdpa.config.model_type != "falcon":
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self,
|
||||
torch_dtype: str,
|
||||
use_attention_mask_options: tuple[Optional[str], ...] = (None, "left", "right"),
|
||||
logit_keys: tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
|
||||
):
|
||||
if not self.all_model_classes[0]._supports_sdpa:
|
||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||
|
||||
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
|
||||
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
|
||||
|
||||
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
|
||||
self.skipTest(
|
||||
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
|
||||
)
|
||||
|
||||
# Convert to torch dtype
|
||||
dtypes = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
torch_dtype = dtypes[torch_dtype]
|
||||
|
||||
atols = {
|
||||
torch.float32: 1e-5,
|
||||
torch.bfloat16: 3e-2,
|
||||
torch.float16: 5e-3,
|
||||
}
|
||||
rtols = {
|
||||
torch.float32: 1e-4,
|
||||
torch.bfloat16: 3e-2,
|
||||
torch.float16: 5e-3,
|
||||
}
|
||||
|
||||
atol = atols[torch_dtype]
|
||||
rtol = rtols[torch_dtype]
|
||||
|
||||
def get_mean_reldiff(msg, current_case, x, ref, atol, rtol):
|
||||
return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# Load the model with SDPA
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
# Load model with eager attention
|
||||
model_eager = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
|
||||
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time,
|
||||
# but it would be nicer to have an efficient way to use parameterized.expand
|
||||
cases = [
|
||||
(use_mask, output_attentions, sdpa_backend, batch_size)
|
||||
for use_mask in use_attention_mask_options
|
||||
for output_attentions in [True, False]
|
||||
for sdpa_backend in [
|
||||
[SDPBackend.MATH],
|
||||
[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH],
|
||||
[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
|
||||
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
|
||||
]
|
||||
for batch_size in [1, 5]
|
||||
]
|
||||
fail_cases = []
|
||||
|
||||
for use_mask, output_attentions, sdpa_backend, batch_size in cases:
|
||||
processed_inputs = inputs_dict.copy()
|
||||
|
||||
# convert to torch_dtype
|
||||
if "pixel_values" in processed_inputs:
|
||||
processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype)
|
||||
|
||||
# slice for different batch sizes
|
||||
for key in ["pixel_values", "input_ids", "attention_mask"]:
|
||||
if key in processed_inputs:
|
||||
processed_inputs[key] = processed_inputs[key][:batch_size]
|
||||
|
||||
# set attention mask with left padding
|
||||
if not use_mask:
|
||||
processed_inputs.pop("attention_mask", None)
|
||||
elif use_mask == "left":
|
||||
dummy_attention_mask = processed_inputs["attention_mask"]
|
||||
dummy_attention_mask[:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
processed_inputs["attention_mask"] = dummy_attention_mask
|
||||
elif use_mask == "right":
|
||||
dummy_attention_mask = processed_inputs["attention_mask"]
|
||||
dummy_attention_mask[:] = 1
|
||||
dummy_attention_mask[:, -1:] = 0
|
||||
processed_inputs["attention_mask"] = dummy_attention_mask
|
||||
else:
|
||||
raise ValueError(f"Invalid value for use_mask={use_mask}")
|
||||
|
||||
processed_inputs["output_attentions"] = output_attentions
|
||||
processed_inputs["output_hidden_states"] = True
|
||||
|
||||
current_case = f"use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}"
|
||||
|
||||
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
|
||||
|
||||
with torch.no_grad():
|
||||
try:
|
||||
with sdpa_kernel(sdpa_backend):
|
||||
outputs_eager = model_eager(**prepared_inputs)
|
||||
outputs_sdpa = model_sdpa(**prepared_inputs)
|
||||
except Exception as e:
|
||||
fail_cases.append(f"{current_case}: {e}")
|
||||
continue
|
||||
|
||||
keys = set(logit_keys) & set(outputs_eager.keys())
|
||||
self.assertTrue(
|
||||
keys, f"Keys {logit_keys} not found in outputs. Available keys: {outputs_eager.keys()}"
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
try:
|
||||
eager_logits = outputs_eager[key]
|
||||
sdpa_logits = outputs_sdpa[key]
|
||||
except KeyError:
|
||||
raise KeyError(f"Key {key} not found in outputs. Available keys: {outputs_eager.keys()}")
|
||||
|
||||
if "hidden_state" in key and use_mask == "left":
|
||||
eager_logits = eager_logits[:, 1:]
|
||||
sdpa_logits = sdpa_logits[:, 1:]
|
||||
elif "hidden_state" in key and use_mask == "right":
|
||||
eager_logits = eager_logits[:, :-1]
|
||||
sdpa_logits = sdpa_logits[:, :-1]
|
||||
|
||||
is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol)
|
||||
if not is_close:
|
||||
fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol))
|
||||
|
||||
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
|
||||
|
||||
|
||||
@require_torch
|
||||
class CLIPVisionModelTest(CLIPModelTesterMixin, unittest.TestCase):
|
||||
@@ -458,16 +293,12 @@ class CLIPVisionModelTest(CLIPModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
self.assertTrue(hasattr(model, "visual_projection"))
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
super().test_eager_matches_sdpa_inference(
|
||||
torch_dtype=torch_dtype,
|
||||
logit_keys=("last_hidden_state", "pooler_output", "image_embeds"),
|
||||
use_attention_mask_options=(None,),
|
||||
)
|
||||
def test_eager_matches_sdpa_inference(self, *args):
|
||||
# adding only flaky decorator here and call the parent test method
|
||||
return getattr(ModelTesterMixin, self._testMethodName)(self)
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
@@ -632,16 +463,13 @@ class CLIPTextModelTest(CLIPModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
self.assertTrue(hasattr(model, "text_projection"))
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
super().test_eager_matches_sdpa_inference(
|
||||
torch_dtype=torch_dtype,
|
||||
logit_keys=("last_hidden_state", "pooler_output", "text_embeds"),
|
||||
use_attention_mask_options=(None, "right"), # "left" is not supported for text model
|
||||
)
|
||||
def test_eager_matches_sdpa_inference(self, *args):
|
||||
# adding only flaky decorator here and call the parent test method
|
||||
return getattr(ModelTesterMixin, self._testMethodName)(self)
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
@@ -860,16 +688,13 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
model = CLIPModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
super().test_eager_matches_sdpa_inference(
|
||||
torch_dtype=torch_dtype,
|
||||
logit_keys=("logits_per_image", "logits_per_text"),
|
||||
use_attention_mask_options=(None, "right"), # "left" is not supported for text model
|
||||
)
|
||||
def test_eager_matches_sdpa_inference(self, *args):
|
||||
# adding only flaky decorator here and call the parent test method
|
||||
return getattr(ModelTesterMixin, self._testMethodName)(self)
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
@@ -1033,16 +858,13 @@ class CLIPForImageClassificationModelTest(CLIPModelTesterMixin, PipelineTesterMi
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
super().test_eager_matches_sdpa_inference(
|
||||
torch_dtype=torch_dtype,
|
||||
logit_keys=("logits",),
|
||||
use_attention_mask_options=(None,),
|
||||
)
|
||||
def test_eager_matches_sdpa_inference(self, *args):
|
||||
# adding only flaky decorator here and call the parent test method
|
||||
return getattr(ModelTesterMixin, self._testMethodName)(self)
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
@@ -1062,7 +884,7 @@ class CLIPModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference(self):
|
||||
model_name = "openai/clip-vit-base-patch32"
|
||||
model = CLIPModel.from_pretrained(model_name).to(torch_device)
|
||||
model = CLIPModel.from_pretrained(model_name, attn_implementation="sdpa").to(torch_device)
|
||||
processor = CLIPProcessor.from_pretrained(model_name)
|
||||
|
||||
image = prepare_img()
|
||||
@@ -1122,5 +944,5 @@ class CLIPModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
|
||||
torch.testing.assert_close(
|
||||
outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4
|
||||
outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=6e-3, atol=4e-4
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user