Gemma capping (#34282)
* softcapping * soft cap before the mask * style * ... * super nit * update * fixes * update * small issue with modular * fix modular imports * update * fixup * simplify a hell lot * simplify cleaning imports * finish fixing * update our design * nits * use a deprecation cycle * updates * Fix modular (recursive deps need to always be computed after merges!) * push * fix * update * fix modular order * make fix-copies * updates * update * ? * don't compile for now * ? * fix some stuff * donc! * fix copies * update * fixup * ? * fix two tests * fix? * for now, don't use head info * eager when output attentoin and sdpa or flash as it's the simplest behaviour (for our tests as well :)) * fix-copies * revert sdpa check * Apply suggestions from code review Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co> * rebase, fix-copies and push * add a slow integration test * update the test * fix left padding issue * fix test * remove duplicate scaling * quality * add a small test and make sure it works * 2b --------- Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -1519,6 +1519,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"eager",
|
||||
"sdpa",
|
||||
"flash_attention_2",
|
||||
"flex_attention",
|
||||
]:
|
||||
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
|
||||
if cls._supports_flash_attn_2:
|
||||
|
||||
@@ -41,7 +41,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_greater_or_equal,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@@ -51,6 +51,8 @@ from .configuration_gemma2 import Gemma2Config
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
if is_torch_greater_or_equal("2.5"):
|
||||
from torch.nn.attention.flex_attention import flex_attention
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -168,6 +170,127 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(config, query, key, value, mask, **_kwargs):
|
||||
key_states = repeat_kv(key, config.num_key_value_groups)
|
||||
value_states = repeat_kv(value, config.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling
|
||||
|
||||
if config.attn_logit_softcapping is not None:
|
||||
attn_weights = attn_weights / config.attn_logit_softcapping
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * config.attn_logit_softcapping
|
||||
if mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
|
||||
if mask is not None:
|
||||
seq_len = mask.shape[1]
|
||||
query = query[:, :, :seq_len]
|
||||
value = value[:, :, :seq_len]
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout
|
||||
# [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
|
||||
query_states = query.transpose(1, 2)
|
||||
key_states = key.transpose(1, 2)
|
||||
value_states = value.transpose(1, 2)
|
||||
|
||||
dropout_rate = config.attention_dropout if config.training else 0.0
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
mask,
|
||||
seq_len,
|
||||
dropout=dropout_rate,
|
||||
softmax_scale=config.scaling,
|
||||
is_causal=config.is_causal,
|
||||
sliding_window=config.sliding_window,
|
||||
use_top_left_mask=config._flash_attn_uses_top_left_mask,
|
||||
softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
|
||||
)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
|
||||
def tanh_softcap(score, b, h, q_idx, kv_idx):
|
||||
soft_cap = config.attn_logit_softcapping
|
||||
score = soft_cap * torch.tanh(score / soft_cap)
|
||||
if mask is not None:
|
||||
return score + mask[b][0][q_idx][kv_idx]
|
||||
return score
|
||||
|
||||
attn_output = flex_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
score_mod=tanh_softcap,
|
||||
enable_gqa=True,
|
||||
scale=config.scaling,
|
||||
return_lse=output_attentions,
|
||||
)
|
||||
if not output_attentions:
|
||||
return attn_output, None
|
||||
else:
|
||||
return attn_output[0], attn_output[1]
|
||||
|
||||
|
||||
def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
|
||||
key = repeat_kv(key, config.num_key_value_groups)
|
||||
value = repeat_kv(value, config.num_key_value_groups)
|
||||
|
||||
causal_mask = mask
|
||||
if mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query.device.type == "cuda" and causal_mask is not None:
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and query.shape[1] > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=config.attention_dropout if config.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
scale=config.scaling,
|
||||
)
|
||||
return attn_output, None
|
||||
|
||||
|
||||
GEMMA2_ATTENTION_FUNCTION = {
|
||||
"flash_attention_2": flash_attention_forward,
|
||||
"flex_attention": flex_attention_forward,
|
||||
"eager": eager_attention_forward,
|
||||
"sdpa": sdpa_attention_forward,
|
||||
}
|
||||
|
||||
|
||||
class Gemma2Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@@ -175,12 +298,6 @@ class Gemma2Attention(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -192,7 +309,8 @@ class Gemma2Attention(nn.Module):
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
self.scaling = config.query_pre_attn_scalar**-0.5
|
||||
|
||||
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||
self.attn_logit_softcapping = config.attn_logit_softcapping
|
||||
if self.hidden_size % self.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
@@ -208,7 +326,6 @@ class Gemma2Attention(nn.Module):
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -243,145 +360,14 @@ class Gemma2Attention(nn.Module):
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if self.config.attn_logit_softcapping is not None:
|
||||
attn_weights = attn_weights / self.config.attn_logit_softcapping
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * self.config.attn_logit_softcapping
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Gemma2FlashAttention2(Gemma2Attention):
|
||||
"""
|
||||
Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
if attention_mask is not None:
|
||||
seq_len = attention_mask.shape[1]
|
||||
key_states = key_states[:, :, :seq_len]
|
||||
value_states = value_states[:, :, :seq_len]
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (Gemma2RMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
|
||||
logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
|
||||
attention_type = "eager"
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
attention_type = self.config._attn_implementation
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
softmax_scale=self.scaling,
|
||||
is_causal=self.is_causal,
|
||||
sliding_window=self.sliding_window,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
|
||||
attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
|
||||
self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
@@ -393,116 +379,37 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Gemma2SdpaAttention(Gemma2Attention):
|
||||
"""
|
||||
Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from Gemma2Attention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
class Gemma2FlashAttention2(Gemma2Attention):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
|
||||
super().__init__(config, layer_idx)
|
||||
self.config._attn_implementation = "flash_attention_2"
|
||||
logger.warning_once(
|
||||
"Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
|
||||
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
scale=self.scaling,
|
||||
class Gemma2SdpaAttention(Gemma2Attention):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
|
||||
super().__init__(config, layer_idx)
|
||||
self.config._attn_implementation = "sdpa"
|
||||
logger.warning_once(
|
||||
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
|
||||
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
GEMMA2_ATTENTION_CLASSES = {
|
||||
"eager": Gemma2Attention,
|
||||
"flash_attention_2": Gemma2FlashAttention2,
|
||||
"sdpa": Gemma2SdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Gemma2DecoderLayer(nn.Module):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
self.config = config
|
||||
self.is_sliding = not bool(layer_idx % 2)
|
||||
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma2MLP(config)
|
||||
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.config = config
|
||||
self.is_sliding = not bool(layer_idx % 2)
|
||||
|
||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
@@ -517,25 +424,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
||||
query_sequence_length, key_sequence_length)` if default attention is used.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# Flash-attn is a 2D tensor
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
|
||||
@@ -29,18 +29,17 @@ from ...modeling_outputs import (
|
||||
from ...utils import (
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_greater_or_equal,
|
||||
logging,
|
||||
)
|
||||
from ..gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaDecoderLayer,
|
||||
GemmaForCausalLM,
|
||||
GemmaForSequenceClassification,
|
||||
GemmaForTokenClassification,
|
||||
GemmaModel,
|
||||
GemmaPreTrainedModel,
|
||||
GemmaRMSNorm,
|
||||
GemmaRotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
@@ -49,6 +48,9 @@ from ..gemma.modeling_gemma import (
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
if is_torch_greater_or_equal("2.5"):
|
||||
from torch.nn.attention.flex_attention import flex_attention
|
||||
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/gemma2-7b"
|
||||
|
||||
@@ -207,13 +209,166 @@ class Gemma2MLP(nn.Module):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Gemma2Attention(GemmaAttention):
|
||||
class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
def eager_attention_forward(config, query, key, value, mask, **_kwargs):
|
||||
key_states = repeat_kv(key, config.num_key_value_groups)
|
||||
value_states = repeat_kv(value, config.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling
|
||||
|
||||
if config.attn_logit_softcapping is not None:
|
||||
attn_weights = attn_weights / config.attn_logit_softcapping
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * config.attn_logit_softcapping
|
||||
if mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
|
||||
if mask is not None:
|
||||
seq_len = mask.shape[1]
|
||||
query = query[:, :, :seq_len]
|
||||
value = value[:, :, :seq_len]
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout
|
||||
# [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
|
||||
query_states = query.transpose(1, 2)
|
||||
key_states = key.transpose(1, 2)
|
||||
value_states = value.transpose(1, 2)
|
||||
|
||||
dropout_rate = config.attention_dropout if config.training else 0.0
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
mask,
|
||||
seq_len,
|
||||
dropout=dropout_rate,
|
||||
softmax_scale=config.scaling,
|
||||
is_causal=config.is_causal,
|
||||
sliding_window=config.sliding_window,
|
||||
use_top_left_mask=config._flash_attn_uses_top_left_mask,
|
||||
softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
|
||||
)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
|
||||
def tanh_softcap(score, b, h, q_idx, kv_idx):
|
||||
soft_cap = config.attn_logit_softcapping
|
||||
score = soft_cap * torch.tanh(score / soft_cap)
|
||||
if mask is not None:
|
||||
return score + mask[b][0][q_idx][kv_idx]
|
||||
return score
|
||||
|
||||
attn_output = flex_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
score_mod=tanh_softcap,
|
||||
enable_gqa=True,
|
||||
scale=config.scaling,
|
||||
return_lse=output_attentions,
|
||||
)
|
||||
if not output_attentions:
|
||||
return attn_output, None
|
||||
else:
|
||||
return attn_output[0], attn_output[1]
|
||||
|
||||
|
||||
def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
|
||||
key = repeat_kv(key, config.num_key_value_groups)
|
||||
value = repeat_kv(value, config.num_key_value_groups)
|
||||
|
||||
causal_mask = mask
|
||||
if mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query.device.type == "cuda" and causal_mask is not None:
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and query.shape[1] > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=config.attention_dropout if config.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
scale=config.scaling,
|
||||
)
|
||||
return attn_output, None
|
||||
|
||||
|
||||
GEMMA2_ATTENTION_FUNCTION = {
|
||||
"flash_attention_2": flash_attention_forward,
|
||||
"flex_attention": flex_attention_forward,
|
||||
"eager": eager_attention_forward,
|
||||
"sdpa": sdpa_attention_forward,
|
||||
}
|
||||
|
||||
|
||||
class Gemma2Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
|
||||
super().__init__(config, layer_idx)
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
self.scaling = config.query_pre_attn_scalar**-0.5
|
||||
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||
self.attn_logit_softcapping = config.attn_logit_softcapping
|
||||
if self.hidden_size % self.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
||||
self.rotary_emb = Gemma2RotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -248,145 +403,14 @@ class Gemma2Attention(GemmaAttention):
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if self.config.attn_logit_softcapping is not None:
|
||||
attn_weights = attn_weights / self.config.attn_logit_softcapping
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * self.config.attn_logit_softcapping
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Gemma2FlashAttention2(Gemma2Attention):
|
||||
"""
|
||||
Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
if attention_mask is not None:
|
||||
seq_len = attention_mask.shape[1]
|
||||
key_states = key_states[:, :, :seq_len]
|
||||
value_states = value_states[:, :, :seq_len]
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (Gemma2RMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
|
||||
logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
|
||||
attention_type = "eager"
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
attention_type = self.config._attn_implementation
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
softmax_scale=self.scaling,
|
||||
is_causal=self.is_causal,
|
||||
sliding_window=self.sliding_window,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
|
||||
attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
|
||||
self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
@@ -398,105 +422,37 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Gemma2SdpaAttention(Gemma2Attention):
|
||||
"""
|
||||
Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from Gemma2Attention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
scale=self.scaling,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class Gemma2DecoderLayer(GemmaDecoderLayer):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||
class Gemma2FlashAttention2(Gemma2Attention):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
|
||||
super().__init__(config, layer_idx)
|
||||
self.config._attn_implementation = "flash_attention_2"
|
||||
logger.warning_once(
|
||||
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
|
||||
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
|
||||
)
|
||||
|
||||
|
||||
class Gemma2SdpaAttention(Gemma2Attention):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
|
||||
super().__init__(config, layer_idx)
|
||||
self.config._attn_implementation = "sdpa"
|
||||
logger.warning_once(
|
||||
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
|
||||
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
|
||||
)
|
||||
|
||||
|
||||
class Gemma2DecoderLayer(nn.Module):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.config = config
|
||||
self.is_sliding = not bool(layer_idx % 2)
|
||||
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma2MLP(config)
|
||||
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@@ -209,6 +209,7 @@ from .import_utils import (
|
||||
is_torch_fp16_available_on_device,
|
||||
is_torch_fx_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torch_greater_or_equal,
|
||||
is_torch_mlu_available,
|
||||
is_torch_mps_available,
|
||||
is_torch_musa_available,
|
||||
|
||||
@@ -929,6 +929,14 @@ def is_flash_attn_greater_or_equal(library_version: str):
|
||||
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def is_torch_greater_or_equal(library_version: str):
|
||||
if not _is_package_available("torch"):
|
||||
return False
|
||||
|
||||
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
|
||||
|
||||
|
||||
def is_torchdistx_available():
|
||||
return _torchdistx_available
|
||||
|
||||
|
||||
@@ -1496,7 +1496,7 @@ class GenerationTesterMixin:
|
||||
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
|
||||
|
||||
# They should result in very similar logits
|
||||
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
|
||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_past_key_values_format(self):
|
||||
|
||||
@@ -199,19 +199,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_sdpa_equivalence(self):
|
||||
pass
|
||||
|
||||
def test_eager_attention_loaded_by_default(self):
|
||||
"""Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default."""
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Usually we enable SDPA by default, but not for Gemma2
|
||||
model = Gemma2Model(config)
|
||||
self.assertTrue(model.config._attn_implementation == "eager")
|
||||
|
||||
# We can still force SDPA
|
||||
config._attn_implementation = "sdpa"
|
||||
model = Gemma2Model(config)
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@@ -277,9 +264,30 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
"Hi today I'm going to be talking about the history of the United States. The United States of America",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
|
||||
).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
||||
|
||||
output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)
|
||||
|
||||
self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
|
||||
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_pipeline_bf16_flex_attention(self):
|
||||
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
|
||||
model_id = "google/gemma-2-2b"
|
||||
# EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1960s and I am trying to find out what the average",
|
||||
"Hi today I'm going to be talking about the 10 best anime of all time.\n\n1",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
|
||||
).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
||||
|
||||
@@ -365,3 +373,23 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
)
|
||||
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
||||
|
||||
@require_read_token
|
||||
def test_model_9b_bf16_flex_attention(self):
|
||||
model_id = "google/gemma-2-9b"
|
||||
EXPECTED_TEXTS = [
|
||||
"<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
|
||||
).to(torch_device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@@ -153,9 +153,9 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
|
||||
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
|
||||
# Handle ClassB.call_to_method
|
||||
if (
|
||||
isinstance(original_node.value, cst.Name)
|
||||
m.matches(original_node.value, m.Name())
|
||||
and original_node.value.value in self.all_bases
|
||||
and isinstance(original_node.attr, cst.Name)
|
||||
and m.matches(original_node.attr, m.Name())
|
||||
):
|
||||
# Replace with super().call_to_method
|
||||
return updated_node.with_changes(
|
||||
@@ -163,10 +163,10 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
|
||||
)
|
||||
# Handle ClassB().call_to_method
|
||||
elif (
|
||||
isinstance(original_node.value, cst.Call)
|
||||
and isinstance(original_node.value.func, cst.Name)
|
||||
m.matches(original_node.value, m.Call())
|
||||
and m.matches(original_node.value.func, m.Name())
|
||||
and original_node.value.func.value in self.all_bases
|
||||
and isinstance(original_node.attr, cst.Name)
|
||||
and m.matches(original_node.attr, m.Name())
|
||||
):
|
||||
# Replace with super().call_to_method
|
||||
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
|
||||
@@ -174,16 +174,16 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
|
||||
|
||||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
|
||||
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
|
||||
if isinstance(original_node.func, cst.Attribute) and (
|
||||
if m.matches(original_node.func, m.Attribute()) and (
|
||||
# Match ClassB().func_a(...)
|
||||
(
|
||||
isinstance(original_node.func.value, cst.Call)
|
||||
and isinstance(original_node.func.value.func, cst.Name)
|
||||
m.matches(original_node.func.value, m.Call())
|
||||
and m.matches(original_node.func.value.func, m.Name())
|
||||
and original_node.func.value.func.value in self.all_bases
|
||||
)
|
||||
or
|
||||
# Match ClassB.func_a(...)
|
||||
(isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
|
||||
(m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases)
|
||||
):
|
||||
# Check if the first argument is 'self', and remove it
|
||||
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
|
||||
@@ -632,8 +632,10 @@ class ModuleMapper(CSTVisitor, ABC):
|
||||
for id, node in self.global_nodes.items():
|
||||
self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
|
||||
|
||||
# Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that
|
||||
# are not part of the recorded objects (i.e. built-in variables, imports, etc)
|
||||
def _restrict_dependencies_to_known_entities(self):
|
||||
"""Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
|
||||
are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
|
||||
This should be called only after all merging operations have been finalized!!"""
|
||||
global_objects = set(self.global_nodes.keys())
|
||||
for object_name, dependencies in self.object_dependency_mapping.items():
|
||||
self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}
|
||||
@@ -814,6 +816,8 @@ class ModelFileMapper(ModuleMapper):
|
||||
# Correctly re-set the global nodes at this point
|
||||
self.global_nodes.update(self.functions)
|
||||
self.global_nodes.update(self.assignments)
|
||||
# Restrict the dependency mappings to the know entities to avoid Python's built-ins
|
||||
self._restrict_dependencies_to_known_entities()
|
||||
# Create the global mapping of recursive dependencies for functions and assignments
|
||||
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
|
||||
|
||||
@@ -1142,22 +1146,20 @@ class ModularFileMapper(ModuleMapper):
|
||||
if assigned_variable == "__all__":
|
||||
self.all_all_to_add = split_all_assignment(node)
|
||||
else:
|
||||
self.current_assignment = assigned_variable
|
||||
self.assignments[assigned_variable] = node
|
||||
|
||||
def leave_Module(self, node):
|
||||
"""When we leave the modular file, we do the following in order:
|
||||
1. compute the nested (recursive) function and assignment dependencies
|
||||
2. for each modeling file found in the imports, rename it with the new model name, visit it, and update
|
||||
1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
|
||||
its dependency graph with the new function and assignment definitions found in the modular
|
||||
3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
|
||||
2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
|
||||
3. compute the nested (recursive) function and assignment dependencies
|
||||
"""
|
||||
# Takes care of finalizing our visit
|
||||
super().leave_Module(node)
|
||||
|
||||
# 1. compute the nested (recursive) function and assignment dependencies
|
||||
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
|
||||
|
||||
# 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
|
||||
# 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
|
||||
self.visited_modules = {}
|
||||
self.renamers = {}
|
||||
for file, module in self.model_specific_modules.items():
|
||||
@@ -1177,10 +1179,13 @@ class ModularFileMapper(ModuleMapper):
|
||||
# We record it so that we can rename classes later the exact same way
|
||||
self.renamers[file] = renamer
|
||||
|
||||
# 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
|
||||
# 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
|
||||
# definitions found in the visited files
|
||||
self.merge_model_specific_imports(self.visited_modules)
|
||||
|
||||
# 3. compute the nested (recursive) function and assignment dependencies
|
||||
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
|
||||
|
||||
# We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later
|
||||
# Note that we may visit several of the same file types, thus we save them per file type, not file
|
||||
self.imported_objects_per_file = defaultdict(set)
|
||||
@@ -1200,9 +1205,9 @@ class ModularFileMapper(ModuleMapper):
|
||||
if object_name in visited_module.functions and object_name not in self.functions:
|
||||
self.functions[object_name] = visited_module.functions[object_name]
|
||||
self.added_objects_file_mapping[object_name] = file
|
||||
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
|
||||
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
|
||||
if dependencies is not None:
|
||||
self.object_recursive_dependency_mapping[object_name] = dependencies
|
||||
self.object_dependency_mapping[object_name] = dependencies
|
||||
for dep in dependencies:
|
||||
if dep not in self.global_nodes:
|
||||
self.added_objects_file_mapping[dep] = file
|
||||
@@ -1212,9 +1217,9 @@ class ModularFileMapper(ModuleMapper):
|
||||
elif object_name in visited_module.assignments and object_name not in self.assignments:
|
||||
self.assignments[object_name] = visited_module.assignments[object_name]
|
||||
self.added_objects_file_mapping[object_name] = file
|
||||
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
|
||||
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
|
||||
if dependencies is not None:
|
||||
self.object_recursive_dependency_mapping[object_name] = dependencies
|
||||
self.object_dependency_mapping[object_name] = dependencies
|
||||
for dep in dependencies:
|
||||
if dep not in self.global_nodes:
|
||||
self.added_objects_file_mapping[dep] = file
|
||||
@@ -1222,6 +1227,8 @@ class ModularFileMapper(ModuleMapper):
|
||||
|
||||
# Do not forget to re-assign all nodes after the merge
|
||||
self.global_nodes = {**self.assignments, **self.classes, **self.functions}
|
||||
# And restric dependencies to those nodes only
|
||||
self._restrict_dependencies_to_known_entities()
|
||||
|
||||
def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
|
||||
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
|
||||
@@ -1239,10 +1246,11 @@ class ModularFileMapper(ModuleMapper):
|
||||
else:
|
||||
original_dependencies.append(dep)
|
||||
# Sort all lists according to the order in their respective file
|
||||
all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
|
||||
all_dependencies = []
|
||||
for file, dependencies in other_files_dependencies.items():
|
||||
sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x])
|
||||
all_dependencies += sorted_dependencies
|
||||
all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x])
|
||||
|
||||
# Add all original node first, then merged ones (one file at a time)
|
||||
for dep in all_dependencies:
|
||||
@@ -1485,7 +1493,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--files_to_parse",
|
||||
default=["src/transformers/models/gemma/modular_gemma.py"],
|
||||
default=["src/transformers/models/gemma2/modular_gemma2.py"],
|
||||
nargs="+",
|
||||
help="A list of `modular_xxxx` files that should be converted to single model file",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user