Gemma capping (#34282)

* softcapping

* soft cap before the mask

* style

* ...

* super nit

* update

* fixes

* update

* small issue with modular

* fix modular imports

* update

* fixup

* simplify a hell lot

* simplify cleaning imports

* finish fixing

* update our design

* nits

* use a deprecation cycle

* updates

* Fix modular (recursive deps need to always be computed after merges!)

* push

* fix

* update

* fix modular order

* make fix-copies

* updates

* update

* ?

* don't compile for now

* ?

* fix some stuff

* donc!

* fix copies

* update

* fixup

* ?

* fix two tests

* fix?

* for now, don't use head info

* eager when output attentoin and sdpa or flash as it's the simplest behaviour (for our tests as well :))

* fix-copies

* revert sdpa check

* Apply suggestions from code review

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>

* rebase, fix-copies and push

* add a slow integration test

* update the test

* fix left padding issue

* fix test

* remove duplicate scaling

* quality

* add a small test and make sure it works

* 2b

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
Arthur
2024-11-19 13:52:38 +01:00
committed by GitHub
parent 54739a320e
commit 4bff54f921
8 changed files with 431 additions and 541 deletions

View File

@@ -1519,6 +1519,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"eager",
"sdpa",
"flash_attention_2",
"flex_attention",
]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:

View File

@@ -41,7 +41,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_greater_or_equal,
logging,
replace_return_docstrings,
)
@@ -51,6 +51,8 @@ from .configuration_gemma2 import Gemma2Config
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
if is_torch_greater_or_equal("2.5"):
from torch.nn.attention.flex_attention import flex_attention
logger = logging.get_logger(__name__)
@@ -168,6 +170,127 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(config, query, key, value, mask, **_kwargs):
key_states = repeat_kv(key, config.num_key_value_groups)
value_states = repeat_kv(value, config.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling
if config.attn_logit_softcapping is not None:
attn_weights = attn_weights / config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * config.attn_logit_softcapping
if mask is not None: # no matter the length, we just slice it
causal_mask = mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
if mask is not None:
seq_len = mask.shape[1]
query = query[:, :, :seq_len]
value = value[:, :, :seq_len]
# TODO: These transpose are quite inefficient but Flash Attention requires the layout
# [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
query_states = query.transpose(1, 2)
key_states = key.transpose(1, 2)
value_states = value.transpose(1, 2)
dropout_rate = config.attention_dropout if config.training else 0.0
input_dtype = query_states.dtype
if input_dtype == torch.float32:
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
mask,
seq_len,
dropout=dropout_rate,
softmax_scale=config.scaling,
is_causal=config.is_causal,
sliding_window=config.sliding_window,
use_top_left_mask=config._flash_attn_uses_top_left_mask,
softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
)
return attn_output, None
def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
def tanh_softcap(score, b, h, q_idx, kv_idx):
soft_cap = config.attn_logit_softcapping
score = soft_cap * torch.tanh(score / soft_cap)
if mask is not None:
return score + mask[b][0][q_idx][kv_idx]
return score
attn_output = flex_attention(
query,
key,
value,
score_mod=tanh_softcap,
enable_gqa=True,
scale=config.scaling,
return_lse=output_attentions,
)
if not output_attentions:
return attn_output, None
else:
return attn_output[0], attn_output[1]
def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
key = repeat_kv(key, config.num_key_value_groups)
value = repeat_kv(value, config.num_key_value_groups)
causal_mask = mask
if mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query.device.type == "cuda" and causal_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and query.shape[1] > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
dropout_p=config.attention_dropout if config.training else 0.0,
is_causal=is_causal,
scale=config.scaling,
)
return attn_output, None
GEMMA2_ATTENTION_FUNCTION = {
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"eager": eager_attention_forward,
"sdpa": sdpa_attention_forward,
}
class Gemma2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -175,12 +298,6 @@ class Gemma2Attention(nn.Module):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
@@ -192,7 +309,8 @@ class Gemma2Attention(nn.Module):
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = config.query_pre_attn_scalar**-0.5
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
self.attn_logit_softcapping = config.attn_logit_softcapping
if self.hidden_size % self.num_heads != 0:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
@@ -208,7 +326,6 @@ class Gemma2Attention(nn.Module):
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
def forward(
self,
@@ -243,145 +360,14 @@ class Gemma2Attention(nn.Module):
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if self.config.attn_logit_softcapping is not None:
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Gemma2FlashAttention2(Gemma2Attention):
"""
Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if attention_mask is not None:
seq_len = attention_mask.shape[1]
key_states = key_states[:, :, :seq_len]
value_states = value_states[:, :, :seq_len]
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (Gemma2RMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
attention_type = "eager"
else:
target_dtype = self.q_proj.weight.dtype
attention_type = self.config._attn_implementation
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
softmax_scale=self.scaling,
is_causal=self.is_causal,
sliding_window=self.sliding_window,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -393,116 +379,37 @@ class Gemma2FlashAttention2(Gemma2Attention):
return attn_output, attn_weights, past_key_value
class Gemma2SdpaAttention(Gemma2Attention):
"""
Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from Gemma2Attention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
class Gemma2FlashAttention2(Gemma2Attention):
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.config._attn_implementation = "flash_attention_2"
logger.warning_once(
"Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
scale=self.scaling,
class Gemma2SdpaAttention(Gemma2Attention):
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.config._attn_implementation = "sdpa"
logger.warning_once(
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
GEMMA2_ATTENTION_CLASSES = {
"eager": Gemma2Attention,
"flash_attention_2": Gemma2FlashAttention2,
"sdpa": Gemma2SdpaAttention,
}
class Gemma2DecoderLayer(nn.Module):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.config = config
self.is_sliding = not bool(layer_idx % 2)
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.config = config
self.is_sliding = not bool(layer_idx % 2)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
@@ -517,25 +424,6 @@ class Gemma2DecoderLayer(nn.Module):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
if self.config._attn_implementation == "flash_attention_2":

View File

@@ -29,18 +29,17 @@ from ...modeling_outputs import (
from ...utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_greater_or_equal,
logging,
)
from ..gemma.modeling_gemma import (
GemmaAttention,
GemmaDecoderLayer,
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel,
GemmaPreTrainedModel,
GemmaRMSNorm,
GemmaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
@@ -49,6 +48,9 @@ from ..gemma.modeling_gemma import (
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
if is_torch_greater_or_equal("2.5"):
from torch.nn.attention.flex_attention import flex_attention
_CHECKPOINT_FOR_DOC = "google/gemma2-7b"
@@ -207,13 +209,166 @@ class Gemma2MLP(nn.Module):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class Gemma2Attention(GemmaAttention):
class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
pass
def eager_attention_forward(config, query, key, value, mask, **_kwargs):
key_states = repeat_kv(key, config.num_key_value_groups)
value_states = repeat_kv(value, config.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling
if config.attn_logit_softcapping is not None:
attn_weights = attn_weights / config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * config.attn_logit_softcapping
if mask is not None: # no matter the length, we just slice it
causal_mask = mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
if mask is not None:
seq_len = mask.shape[1]
query = query[:, :, :seq_len]
value = value[:, :, :seq_len]
# TODO: These transpose are quite inefficient but Flash Attention requires the layout
# [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
query_states = query.transpose(1, 2)
key_states = key.transpose(1, 2)
value_states = value.transpose(1, 2)
dropout_rate = config.attention_dropout if config.training else 0.0
input_dtype = query_states.dtype
if input_dtype == torch.float32:
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
mask,
seq_len,
dropout=dropout_rate,
softmax_scale=config.scaling,
is_causal=config.is_causal,
sliding_window=config.sliding_window,
use_top_left_mask=config._flash_attn_uses_top_left_mask,
softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
)
return attn_output, None
def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
def tanh_softcap(score, b, h, q_idx, kv_idx):
soft_cap = config.attn_logit_softcapping
score = soft_cap * torch.tanh(score / soft_cap)
if mask is not None:
return score + mask[b][0][q_idx][kv_idx]
return score
attn_output = flex_attention(
query,
key,
value,
score_mod=tanh_softcap,
enable_gqa=True,
scale=config.scaling,
return_lse=output_attentions,
)
if not output_attentions:
return attn_output, None
else:
return attn_output[0], attn_output[1]
def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
key = repeat_kv(key, config.num_key_value_groups)
value = repeat_kv(value, config.num_key_value_groups)
causal_mask = mask
if mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query.device.type == "cuda" and causal_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and query.shape[1] > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
dropout_p=config.attention_dropout if config.training else 0.0,
is_causal=is_causal,
scale=config.scaling,
)
return attn_output, None
GEMMA2_ATTENTION_FUNCTION = {
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"eager": eager_attention_forward,
"sdpa": sdpa_attention_forward,
}
class Gemma2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = config.query_pre_attn_scalar**-0.5
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
self.attn_logit_softcapping = config.attn_logit_softcapping
if self.hidden_size % self.num_heads != 0:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.rotary_emb = Gemma2RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def forward(
self,
@@ -248,145 +403,14 @@ class Gemma2Attention(GemmaAttention):
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if self.config.attn_logit_softcapping is not None:
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Gemma2FlashAttention2(Gemma2Attention):
"""
Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if attention_mask is not None:
seq_len = attention_mask.shape[1]
key_states = key_states[:, :, :seq_len]
value_states = value_states[:, :, :seq_len]
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (Gemma2RMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
attention_type = "eager"
else:
target_dtype = self.q_proj.weight.dtype
attention_type = self.config._attn_implementation
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
softmax_scale=self.scaling,
is_causal=self.is_causal,
sliding_window=self.sliding_window,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -398,105 +422,37 @@ class Gemma2FlashAttention2(Gemma2Attention):
return attn_output, attn_weights, past_key_value
class Gemma2SdpaAttention(Gemma2Attention):
"""
Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from Gemma2Attention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
scale=self.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class Gemma2DecoderLayer(GemmaDecoderLayer):
def __init__(self, config: Gemma2Config, layer_idx: int):
class Gemma2FlashAttention2(Gemma2Attention):
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.config._attn_implementation = "flash_attention_2"
logger.warning_once(
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
)
class Gemma2SdpaAttention(Gemma2Attention):
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.config._attn_implementation = "sdpa"
logger.warning_once(
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
)
class Gemma2DecoderLayer(nn.Module):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
self.is_sliding = not bool(layer_idx % 2)
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window

View File

@@ -209,6 +209,7 @@ from .import_utils import (
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_greater_or_equal,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,

View File

@@ -929,6 +929,14 @@ def is_flash_attn_greater_or_equal(library_version: str):
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
@lru_cache()
def is_torch_greater_or_equal(library_version: str):
if not _is_package_available("torch"):
return False
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
def is_torchdistx_available():
return _torchdistx_available

View File

@@ -1496,7 +1496,7 @@ class GenerationTesterMixin:
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-5)
@pytest.mark.generate
def test_past_key_values_format(self):

View File

@@ -199,19 +199,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_sdpa_equivalence(self):
pass
def test_eager_attention_loaded_by_default(self):
"""Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default."""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# Usually we enable SDPA by default, but not for Gemma2
model = Gemma2Model(config)
self.assertTrue(model.config._attn_implementation == "eager")
# We can still force SDPA
config._attn_implementation = "sdpa"
model = Gemma2Model(config)
self.assertTrue(model.config._attn_implementation == "sdpa")
@slow
@require_torch_gpu
@@ -277,9 +264,30 @@ class Gemma2IntegrationTest(unittest.TestCase):
"Hi today I'm going to be talking about the history of the United States. The United States of America",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
torch_device
)
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)
self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
@require_read_token
def test_model_2b_pipeline_bf16_flex_attention(self):
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
model_id = "google/gemma-2-2b"
# EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1960s and I am trying to find out what the average",
"Hi today I'm going to be talking about the 10 best anime of all time.\n\n1",
]
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
@@ -365,3 +373,23 @@ class Gemma2IntegrationTest(unittest.TestCase):
)
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
@require_read_token
def test_model_9b_bf16_flex_attention(self):
model_id = "google/gemma-2-9b"
EXPECTED_TEXTS = [
"<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
]
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
self.assertEqual(output_text, EXPECTED_TEXTS)

View File

@@ -153,9 +153,9 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
# Handle ClassB.call_to_method
if (
isinstance(original_node.value, cst.Name)
m.matches(original_node.value, m.Name())
and original_node.value.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(
@@ -163,10 +163,10 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
)
# Handle ClassB().call_to_method
elif (
isinstance(original_node.value, cst.Call)
and isinstance(original_node.value.func, cst.Name)
m.matches(original_node.value, m.Call())
and m.matches(original_node.value.func, m.Name())
and original_node.value.func.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
@@ -174,16 +174,16 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
if isinstance(original_node.func, cst.Attribute) and (
if m.matches(original_node.func, m.Attribute()) and (
# Match ClassB().func_a(...)
(
isinstance(original_node.func.value, cst.Call)
and isinstance(original_node.func.value.func, cst.Name)
m.matches(original_node.func.value, m.Call())
and m.matches(original_node.func.value.func, m.Name())
and original_node.func.value.func.value in self.all_bases
)
or
# Match ClassB.func_a(...)
(isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
(m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases)
):
# Check if the first argument is 'self', and remove it
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
@@ -632,8 +632,10 @@ class ModuleMapper(CSTVisitor, ABC):
for id, node in self.global_nodes.items():
self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
# Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that
# are not part of the recorded objects (i.e. built-in variables, imports, etc)
def _restrict_dependencies_to_known_entities(self):
"""Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
This should be called only after all merging operations have been finalized!!"""
global_objects = set(self.global_nodes.keys())
for object_name, dependencies in self.object_dependency_mapping.items():
self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}
@@ -814,6 +816,8 @@ class ModelFileMapper(ModuleMapper):
# Correctly re-set the global nodes at this point
self.global_nodes.update(self.functions)
self.global_nodes.update(self.assignments)
# Restrict the dependency mappings to the know entities to avoid Python's built-ins
self._restrict_dependencies_to_known_entities()
# Create the global mapping of recursive dependencies for functions and assignments
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
@@ -1142,22 +1146,20 @@ class ModularFileMapper(ModuleMapper):
if assigned_variable == "__all__":
self.all_all_to_add = split_all_assignment(node)
else:
self.current_assignment = assigned_variable
self.assignments[assigned_variable] = node
def leave_Module(self, node):
"""When we leave the modular file, we do the following in order:
1. compute the nested (recursive) function and assignment dependencies
2. for each modeling file found in the imports, rename it with the new model name, visit it, and update
1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
its dependency graph with the new function and assignment definitions found in the modular
3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
3. compute the nested (recursive) function and assignment dependencies
"""
# Takes care of finalizing our visit
super().leave_Module(node)
# 1. compute the nested (recursive) function and assignment dependencies
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
# 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
# 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
self.visited_modules = {}
self.renamers = {}
for file, module in self.model_specific_modules.items():
@@ -1177,10 +1179,13 @@ class ModularFileMapper(ModuleMapper):
# We record it so that we can rename classes later the exact same way
self.renamers[file] = renamer
# 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
# 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
# definitions found in the visited files
self.merge_model_specific_imports(self.visited_modules)
# 3. compute the nested (recursive) function and assignment dependencies
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
# We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later
# Note that we may visit several of the same file types, thus we save them per file type, not file
self.imported_objects_per_file = defaultdict(set)
@@ -1200,9 +1205,9 @@ class ModularFileMapper(ModuleMapper):
if object_name in visited_module.functions and object_name not in self.functions:
self.functions[object_name] = visited_module.functions[object_name]
self.added_objects_file_mapping[object_name] = file
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
self.object_recursive_dependency_mapping[object_name] = dependencies
self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
@@ -1212,9 +1217,9 @@ class ModularFileMapper(ModuleMapper):
elif object_name in visited_module.assignments and object_name not in self.assignments:
self.assignments[object_name] = visited_module.assignments[object_name]
self.added_objects_file_mapping[object_name] = file
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
self.object_recursive_dependency_mapping[object_name] = dependencies
self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
@@ -1222,6 +1227,8 @@ class ModularFileMapper(ModuleMapper):
# Do not forget to re-assign all nodes after the merge
self.global_nodes = {**self.assignments, **self.classes, **self.functions}
# And restric dependencies to those nodes only
self._restrict_dependencies_to_known_entities()
def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
@@ -1239,10 +1246,11 @@ class ModularFileMapper(ModuleMapper):
else:
original_dependencies.append(dep)
# Sort all lists according to the order in their respective file
all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
all_dependencies = []
for file, dependencies in other_files_dependencies.items():
sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x])
all_dependencies += sorted_dependencies
all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x])
# Add all original node first, then merged ones (one file at a time)
for dep in all_dependencies:
@@ -1485,7 +1493,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/gemma/modular_gemma.py"],
default=["src/transformers/models/gemma2/modular_gemma2.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)