diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7672df0b9a..d68166d526 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1519,6 +1519,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "eager", "sdpa", "flash_attention_2", + "flex_attention", ]: message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index c439ec069f..6111261830 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -41,7 +41,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, + is_torch_greater_or_equal, logging, replace_return_docstrings, ) @@ -51,6 +51,8 @@ from .configuration_gemma2 import Gemma2Config if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention logger = logging.get_logger(__name__) @@ -168,6 +170,127 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(config, query, key, value, mask, **_kwargs): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + + if config.attn_logit_softcapping is not None: + attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * config.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + value_states = value.transpose(1, 2) + + dropout_rate = config.attention_dropout if config.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=config.scaling, + is_causal=config.is_causal, + sliding_window=config.sliding_window, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + ) + + return attn_output, None + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + score = soft_cap * torch.tanh(score / soft_cap) + if mask is not None: + return score + mask[b][0][q_idx][kv_idx] + return score + + attn_output = flex_attention( + query, + key, + value, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + if not output_attentions: + return attn_output, None + else: + return attn_output[0], attn_output[1] + + +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output, None + + +GEMMA2_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -175,12 +298,6 @@ class Gemma2Attention(nn.Module): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -192,7 +309,8 @@ class Gemma2Attention(nn.Module): self.rope_theta = config.rope_theta self.is_causal = True self.scaling = config.query_pre_attn_scalar**-0.5 - + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None + self.attn_logit_softcapping = config.attn_logit_softcapping if self.hidden_size % self.num_heads != 0: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" @@ -208,7 +326,6 @@ class Gemma2Attention(nn.Module): max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( self, @@ -243,145 +360,14 @@ class Gemma2Attention(nn.Module): } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") + attention_type = "eager" + else: + attention_type = self.config._attn_implementation - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if self.config.attn_logit_softcapping is not None: - attn_weights = attn_weights / self.config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Gemma2FlashAttention2(Gemma2Attention): - """ - Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if attention_mask is not None: - seq_len = attention_mask.shape[1] - key_states = key_states[:, :, :seq_len] - value_states = value_states[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Gemma2RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.scaling, - is_causal=self.is_causal, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( + self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -393,116 +379,37 @@ class Gemma2FlashAttention2(Gemma2Attention): return attn_output, attn_weights, past_key_value -class Gemma2SdpaAttention(Gemma2Attention): - """ - Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Gemma2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - scale=self.scaling, +class Gemma2FlashAttention2(Gemma2Attention): + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config._attn_implementation = "flash_attention_2" + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -GEMMA2_ATTENTION_CLASSES = { - "eager": Gemma2Attention, - "flash_attention_2": Gemma2FlashAttention2, - "sdpa": Gemma2SdpaAttention, -} +class Gemma2SdpaAttention(Gemma2Attention): + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config._attn_implementation = "sdpa" + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" + ) class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.config = config + self.is_sliding = not bool(layer_idx % 2) + self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.config = config - self.is_sliding = not bool(layer_idx % 2) + self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window @@ -517,25 +424,6 @@ class Gemma2DecoderLayer(nn.Module): use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding # Flash-attn is a 2D tensor if self.config._attn_implementation == "flash_attention_2": diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index ff2d42d671..8d86238632 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -29,18 +29,17 @@ from ...modeling_outputs import ( from ...utils import ( is_flash_attn_2_available, is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, + is_torch_greater_or_equal, logging, ) from ..gemma.modeling_gemma import ( - GemmaAttention, - GemmaDecoderLayer, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, GemmaModel, GemmaPreTrainedModel, GemmaRMSNorm, + GemmaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv, ) @@ -49,6 +48,9 @@ from ..gemma.modeling_gemma import ( if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention + _CHECKPOINT_FOR_DOC = "google/gemma2-7b" @@ -207,13 +209,166 @@ class Gemma2MLP(nn.Module): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -class Gemma2Attention(GemmaAttention): +class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): + pass + + +def eager_attention_forward(config, query, key, value, mask, **_kwargs): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + + if config.attn_logit_softcapping is not None: + attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * config.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + value_states = value.transpose(1, 2) + + dropout_rate = config.attention_dropout if config.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=config.scaling, + is_causal=config.is_causal, + sliding_window=config.sliding_window, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + ) + + return attn_output, None + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + score = soft_cap * torch.tanh(score / soft_cap) + if mask is not None: + return score + mask[b][0][q_idx][kv_idx] + return score + + attn_output = flex_attention( + query, + key, + value, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + if not output_attentions: + return attn_output, None + else: + return attn_output[0], attn_output[1] + + +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output, None + + +GEMMA2_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None + self.attn_logit_softcapping = config.attn_logit_softcapping + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = Gemma2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) def forward( self, @@ -248,145 +403,14 @@ class Gemma2Attention(GemmaAttention): } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") + attention_type = "eager" + else: + attention_type = self.config._attn_implementation - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if self.config.attn_logit_softcapping is not None: - attn_weights = attn_weights / self.config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Gemma2FlashAttention2(Gemma2Attention): - """ - Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if attention_mask is not None: - seq_len = attention_mask.shape[1] - key_states = key_states[:, :, :seq_len] - value_states = value_states[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Gemma2RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.scaling, - is_causal=self.is_causal, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( + self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -398,105 +422,37 @@ class Gemma2FlashAttention2(Gemma2Attention): return attn_output, attn_weights, past_key_value -class Gemma2SdpaAttention(Gemma2Attention): - """ - Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Gemma2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - scale=self.scaling, +class Gemma2FlashAttention2(Gemma2Attention): + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config._attn_implementation = "flash_attention_2" + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -class Gemma2DecoderLayer(GemmaDecoderLayer): - def __init__(self, config: Gemma2Config, layer_idx: int): +class Gemma2SdpaAttention(Gemma2Attention): + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) + self.config._attn_implementation = "sdpa" + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" + ) + + +class Gemma2DecoderLayer(nn.Module): + def __init__(self, config: Gemma2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size self.config = config self.is_sliding = not bool(layer_idx % 2) + self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) + self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2a10bcaa3c..492642d61b 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -209,6 +209,7 @@ from .import_utils import ( is_torch_fp16_available_on_device, is_torch_fx_available, is_torch_fx_proxy, + is_torch_greater_or_equal, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 173aee9b1a..6306efa2fa 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -929,6 +929,14 @@ def is_flash_attn_greater_or_equal(library_version: str): return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) +@lru_cache() +def is_torch_greater_or_equal(library_version: str): + if not _is_package_available("torch"): + return False + + return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version) + + def is_torchdistx_available(): return _torchdistx_available diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6630fc2ba9..b1d0042c65 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1496,7 +1496,7 @@ class GenerationTesterMixin: next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] # They should result in very similar logits - self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5)) + torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-5) @pytest.mark.generate def test_past_key_values_format(self): diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 7bca83f96d..06116c4dba 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -199,19 +199,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): def test_sdpa_equivalence(self): pass - def test_eager_attention_loaded_by_default(self): - """Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default.""" - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - # Usually we enable SDPA by default, but not for Gemma2 - model = Gemma2Model(config) - self.assertTrue(model.config._attn_implementation == "eager") - - # We can still force SDPA - config._attn_implementation = "sdpa" - model = Gemma2Model(config) - self.assertTrue(model.config._attn_implementation == "sdpa") - @slow @require_torch_gpu @@ -277,9 +264,30 @@ class Gemma2IntegrationTest(unittest.TestCase): "Hi today I'm going to be talking about the history of the United States. The United States of America", ] - model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to( - torch_device - ) + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention" + ).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_id) + pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) + + output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True) + + self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0]) + self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1]) + + @require_read_token + def test_model_2b_pipeline_bf16_flex_attention(self): + # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR + model_id = "google/gemma-2-2b" + # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1960s and I am trying to find out what the average", + "Hi today I'm going to be talking about the 10 best anime of all time.\n\n1", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention" + ).to(torch_device) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) @@ -365,3 +373,23 @@ class Gemma2IntegrationTest(unittest.TestCase): ) ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + @require_read_token + def test_model_9b_bf16_flex_attention(self): + model_id = "google/gemma-2-9b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hi today I'm going to be talking about the history of the United States. The United States of America", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention" + ).to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=False) + + self.assertEqual(output_text, EXPECTED_TEXTS) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index b1dfa18a7a..e5f6e34ece 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -153,9 +153,9 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer): def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode: # Handle ClassB.call_to_method if ( - isinstance(original_node.value, cst.Name) + m.matches(original_node.value, m.Name()) and original_node.value.value in self.all_bases - and isinstance(original_node.attr, cst.Name) + and m.matches(original_node.attr, m.Name()) ): # Replace with super().call_to_method return updated_node.with_changes( @@ -163,10 +163,10 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer): ) # Handle ClassB().call_to_method elif ( - isinstance(original_node.value, cst.Call) - and isinstance(original_node.value.func, cst.Name) + m.matches(original_node.value, m.Call()) + and m.matches(original_node.value.func, m.Name()) and original_node.value.func.value in self.all_bases - and isinstance(original_node.attr, cst.Name) + and m.matches(original_node.attr, m.Name()) ): # Replace with super().call_to_method return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super")))) @@ -174,16 +174,16 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer): def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: # Check if the function being called is of the form ClassB().func_a or ClassB.func_a - if isinstance(original_node.func, cst.Attribute) and ( + if m.matches(original_node.func, m.Attribute()) and ( # Match ClassB().func_a(...) ( - isinstance(original_node.func.value, cst.Call) - and isinstance(original_node.func.value.func, cst.Name) + m.matches(original_node.func.value, m.Call()) + and m.matches(original_node.func.value.func, m.Name()) and original_node.func.value.func.value in self.all_bases ) or # Match ClassB.func_a(...) - (isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases) + (m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases) ): # Check if the first argument is 'self', and remove it if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")): @@ -632,8 +632,10 @@ class ModuleMapper(CSTVisitor, ABC): for id, node in self.global_nodes.items(): self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line - # Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that - # are not part of the recorded objects (i.e. built-in variables, imports, etc) + def _restrict_dependencies_to_known_entities(self): + """Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that + are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc). + This should be called only after all merging operations have been finalized!!""" global_objects = set(self.global_nodes.keys()) for object_name, dependencies in self.object_dependency_mapping.items(): self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects} @@ -814,6 +816,8 @@ class ModelFileMapper(ModuleMapper): # Correctly re-set the global nodes at this point self.global_nodes.update(self.functions) self.global_nodes.update(self.assignments) + # Restrict the dependency mappings to the know entities to avoid Python's built-ins + self._restrict_dependencies_to_known_entities() # Create the global mapping of recursive dependencies for functions and assignments self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() @@ -1142,22 +1146,20 @@ class ModularFileMapper(ModuleMapper): if assigned_variable == "__all__": self.all_all_to_add = split_all_assignment(node) else: + self.current_assignment = assigned_variable self.assignments[assigned_variable] = node def leave_Module(self, node): """When we leave the modular file, we do the following in order: - 1. compute the nested (recursive) function and assignment dependencies - 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update + 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update its dependency graph with the new function and assignment definitions found in the modular - 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) + 2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) + 3. compute the nested (recursive) function and assignment dependencies """ # Takes care of finalizing our visit super().leave_Module(node) - # 1. compute the nested (recursive) function and assignment dependencies - self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() - - # 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies + # 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies self.visited_modules = {} self.renamers = {} for file, module in self.model_specific_modules.items(): @@ -1177,10 +1179,13 @@ class ModularFileMapper(ModuleMapper): # We record it so that we can rename classes later the exact same way self.renamers[file] = renamer - # 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the + # 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the # definitions found in the visited files self.merge_model_specific_imports(self.visited_modules) + # 3. compute the nested (recursive) function and assignment dependencies + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() + # We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later # Note that we may visit several of the same file types, thus we save them per file type, not file self.imported_objects_per_file = defaultdict(set) @@ -1200,9 +1205,9 @@ class ModularFileMapper(ModuleMapper): if object_name in visited_module.functions and object_name not in self.functions: self.functions[object_name] = visited_module.functions[object_name] self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + dependencies = visited_module.object_dependency_mapping.get(object_name, None) if dependencies is not None: - self.object_recursive_dependency_mapping[object_name] = dependencies + self.object_dependency_mapping[object_name] = dependencies for dep in dependencies: if dep not in self.global_nodes: self.added_objects_file_mapping[dep] = file @@ -1212,9 +1217,9 @@ class ModularFileMapper(ModuleMapper): elif object_name in visited_module.assignments and object_name not in self.assignments: self.assignments[object_name] = visited_module.assignments[object_name] self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + dependencies = visited_module.object_dependency_mapping.get(object_name, None) if dependencies is not None: - self.object_recursive_dependency_mapping[object_name] = dependencies + self.object_dependency_mapping[object_name] = dependencies for dep in dependencies: if dep not in self.global_nodes: self.added_objects_file_mapping[dep] = file @@ -1222,6 +1227,8 @@ class ModularFileMapper(ModuleMapper): # Do not forget to re-assign all nodes after the merge self.global_nodes = {**self.assignments, **self.classes, **self.functions} + # And restric dependencies to those nodes only + self._restrict_dependencies_to_known_entities() def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that @@ -1239,10 +1246,11 @@ class ModularFileMapper(ModuleMapper): else: original_dependencies.append(dep) # Sort all lists according to the order in their respective file - all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + all_dependencies = [] for file, dependencies in other_files_dependencies.items(): sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) all_dependencies += sorted_dependencies + all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x]) # Add all original node first, then merged ones (one file at a time) for dep in all_dependencies: @@ -1485,7 +1493,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/gemma/modular_gemma.py"], + default=["src/transformers/models/gemma2/modular_gemma2.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", )