From e0aad278fe5cd6feba126e37c4514a1e5a6377ba Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 23 May 2025 19:48:01 +0200 Subject: [PATCH] Never fallback to eager implicitly (#38327) * remove arg everywhere * Update warnings * add more models * Update sdpa_attention.py * fix style * fix * readd warnings but not for flex * Update test_modeling_common.py * skip * fix --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/modular_transformers.md | 8 +---- src/transformers/generation/utils.py | 2 -- .../integrations/flex_attention.py | 5 ++- src/transformers/masking_utils.py | 36 ++----------------- src/transformers/models/aria/modeling_aria.py | 1 - .../models/bitnet/modeling_bitnet.py | 9 +---- .../models/bitnet/modular_bitnet.py | 8 +---- .../models/cohere/modeling_cohere.py | 9 +---- .../models/cohere/modular_cohere.py | 8 +---- .../models/cohere2/modeling_cohere2.py | 9 +---- .../models/cohere2/modular_cohere2.py | 9 +---- src/transformers/models/csm/modeling_csm.py | 2 -- src/transformers/models/csm/modular_csm.py | 1 - .../deepseek_v3/modeling_deepseek_v3.py | 9 +---- .../models/deepseek_v3/modular_deepseek_v3.py | 8 +---- .../models/diffllama/modeling_diffllama.py | 18 ---------- .../models/diffllama/modular_diffllama.py | 17 --------- src/transformers/models/emu3/modeling_emu3.py | 1 - .../models/falcon_h1/modeling_falcon_h1.py | 9 +---- .../models/falcon_h1/modular_falcon_h1.py | 9 +---- .../models/gemma/modeling_gemma.py | 1 - .../models/gemma/modular_gemma.py | 1 - .../models/gemma2/modeling_gemma2.py | 9 +---- .../models/gemma2/modular_gemma2.py | 9 +---- .../models/gemma3/modeling_gemma3.py | 13 +------ .../models/gemma3/modular_gemma3.py | 13 +------ src/transformers/models/glm/modeling_glm.py | 1 - src/transformers/models/glm4/modeling_glm4.py | 1 - .../models/gpt_neox/modeling_gpt_neox.py | 24 ++----------- .../models/gpt_neox/modular_gpt_neox.py | 24 ++----------- .../models/granite/modeling_granite.py | 1 - .../models/granite/modular_granite.py | 1 - .../models/granitemoe/modeling_granitemoe.py | 8 +---- .../modeling_granitemoehybrid.py | 8 +---- .../modeling_granitemoeshared.py | 8 +---- .../models/helium/modeling_helium.py | 1 - .../models/internvl/modeling_internvl.py | 12 +------ .../models/internvl/modular_internvl.py | 8 +---- .../models/janus/modeling_janus.py | 8 +---- .../models/janus/modular_janus.py | 8 +---- .../models/llama/modeling_llama.py | 1 - .../models/llama4/modeling_llama4.py | 17 ++------- .../models/mistral/modeling_mistral.py | 9 +---- .../models/mistral/modular_mistral.py | 9 +---- .../models/mixtral/modeling_mixtral.py | 9 +---- .../models/mixtral/modular_mixtral.py | 1 - src/transformers/models/mlcd/modeling_mlcd.py | 13 ++----- src/transformers/models/mlcd/modular_mlcd.py | 8 +---- .../models/moonshine/modeling_moonshine.py | 9 +---- .../models/moonshine/modular_moonshine.py | 9 +---- src/transformers/models/olmo/modeling_olmo.py | 9 +---- src/transformers/models/olmo/modular_olmo.py | 8 +---- .../models/olmo2/modeling_olmo2.py | 9 +---- .../models/olmo2/modular_olmo2.py | 8 +---- src/transformers/models/phi/modeling_phi.py | 9 +---- src/transformers/models/phi/modular_phi.py | 9 +---- src/transformers/models/phi3/modeling_phi3.py | 9 +---- src/transformers/models/phi3/modular_phi3.py | 8 +---- .../modeling_phi4_multimodal.py | 9 +---- .../modular_phi4_multimodal.py | 1 - .../models/qwen2/modeling_qwen2.py | 9 +---- .../models/qwen2/modular_qwen2.py | 9 +---- .../models/qwen3/modeling_qwen3.py | 9 +---- .../models/qwen3/modular_qwen3.py | 8 +---- .../models/qwen3_moe/modeling_qwen3_moe.py | 9 +---- .../models/starcoder2/modeling_starcoder2.py | 9 +---- .../models/starcoder2/modular_starcoder2.py | 9 +---- .../models/timesfm/modeling_timesfm.py | 8 +---- .../models/timesfm/modular_timesfm.py | 8 +---- .../models/zamba/modeling_zamba.py | 8 +---- .../models/zamba2/modeling_zamba2.py | 9 +---- .../models/zamba2/modular_zamba2.py | 8 +---- tests/test_modeling_common.py | 6 ++-- 73 files changed, 66 insertions(+), 544 deletions(-) diff --git a/docs/source/en/modular_transformers.md b/docs/source/en/modular_transformers.md index badeab0214..84d365f9aa 100644 --- a/docs/source/en/modular_transformers.md +++ b/docs/source/en/modular_transformers.md @@ -243,13 +243,7 @@ class Olmo2Attention(OlmoAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8ea1e7e760..784b2d15b6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -655,7 +655,6 @@ class GenerationMixin(ContinuousMixin): # If it's not defined, it means the model uses the new general mask API if causal_mask_creation_function is None: # can't be found - output_attentions = kwargs.get("output_attentions", False) token_type_ids = getattr(model_input, "token_type_ids", None) # Some models may overwrite the general one causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate) @@ -666,7 +665,6 @@ class GenerationMixin(ContinuousMixin): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, token_type_ids=token_type_ids, ) else: diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index cc9787657b..1e1228873f 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -235,10 +235,9 @@ def flex_attention_forward( head_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - if kwargs.get("output_attentions", False) or head_mask is not None: + if head_mask is not None: logger.warning_once( - "`flex_attention` does not support `output_attentions=True` or `head_mask`." - " Please set your attention to `eager` if you want any of these features." + "`flex_attention` does not support `head_mask`. Please set your attention to `eager` if you want this feature." ) if kwargs.get("dropout", 0.0) > 0: diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 53a81e1daa..8829c0711a 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -644,30 +644,12 @@ def _preprocess_mask_arguments( return False, attention_mask, kv_length, kv_offset -def _get_mask_interface(config: PretrainedConfig, output_attentions: bool = False) -> Callable: - """ - Return the mask interface (a function) to be used, based on the type of attention found in the config. - - Args: - config (`PretrainedConfig`): - The model config. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. - """ - mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] - # Sdpa fallbacks to eager in the Attention modules if `output_attentions=True` - if config._attn_implementation == "sdpa" and output_attentions: - mask_interface = ALL_MASK_ATTENTION_FUNCTIONS["eager"] - return mask_interface - - def create_causal_mask( config: PretrainedConfig, input_embeds: torch.Tensor, attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[torch.Tensor, "BlockMask"]]: @@ -689,8 +671,6 @@ def create_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. or_mask_function (`Callable`, optional): An optional mask function to combine with the causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the causal one, for example for image tokens handling. @@ -712,7 +692,7 @@ def create_causal_mask( batch_size, dtype = input_embeds.shape[0], input_embeds.dtype mask_factory_function = causal_mask_function - mask_interface = _get_mask_interface(config, output_attentions) + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it @@ -751,7 +731,6 @@ def create_sliding_window_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[torch.Tensor, "BlockMask"]]: @@ -774,8 +753,6 @@ def create_sliding_window_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. or_mask_function (`Callable`, optional): An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. @@ -801,7 +778,7 @@ def create_sliding_window_causal_mask( batch_size, dtype = input_embeds.shape[0], input_embeds.dtype mask_factory_function = sliding_window_causal_mask_function(sliding_window) - mask_interface = _get_mask_interface(config, output_attentions) + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it @@ -841,7 +818,6 @@ def create_chunked_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[torch.Tensor, "BlockMask"]]: @@ -864,8 +840,6 @@ def create_chunked_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. or_mask_function (`Callable`, optional): An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. @@ -898,7 +872,7 @@ def create_chunked_causal_mask( batch_size, dtype = input_embeds.shape[0], input_embeds.dtype mask_factory_function = chunked_causal_mask_function(chunk_size) - mask_interface = _get_mask_interface(config, output_attentions) + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it @@ -945,7 +919,6 @@ def create_masks_for_generate( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, **kwargs, @@ -967,8 +940,6 @@ def create_masks_for_generate( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. or_mask_function (`Callable`, optional): An optional mask function to combine with the other mask function (by doing the union of both). This is useful to easily overlay another mask on top of the causal one, for example for image tokens handling. @@ -985,7 +956,6 @@ def create_masks_for_generate( "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, "or_mask_function": or_mask_function, "and_mask_function": and_mask_function, } diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 8f80a7ff08..abb751ab7d 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -805,7 +805,6 @@ class AriaTextModel(AriaTextPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index e98f9ed116..661a3c9bb6 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -205,13 +205,7 @@ class BitNetAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -425,7 +419,6 @@ class BitNetModel(BitNetPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/bitnet/modular_bitnet.py b/src/transformers/models/bitnet/modular_bitnet.py index 0c0d133cb5..c57b7217f1 100644 --- a/src/transformers/models/bitnet/modular_bitnet.py +++ b/src/transformers/models/bitnet/modular_bitnet.py @@ -85,13 +85,7 @@ class BitNetAttention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 37f698a86e..0700eb8e9f 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -261,13 +261,7 @@ class CohereAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -462,7 +456,6 @@ class CohereModel(CoherePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index a44aebcead..e37c875be3 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -184,13 +184,7 @@ class CohereAttention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 144667f1e3..5690864cfc 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -222,13 +222,7 @@ class Cohere2Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -439,7 +433,6 @@ class Cohere2Model(Cohere2PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 792d278cc0..7a5cab506e 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -309,13 +309,7 @@ class Cohere2Attention(CohereAttention, nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -461,7 +455,6 @@ class Cohere2Model(Gemma2Model): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 34abf4f15d..e1f1d477b3 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -509,7 +509,6 @@ class CsmDepthDecoderModel(CsmPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -821,7 +820,6 @@ class CsmBackboneModel(CsmPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 35fdf127fc..aab2d131c4 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -247,7 +247,6 @@ class CsmDepthDecoderModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index b15301e288..5804eeee4b 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -412,13 +412,7 @@ class DeepseekV3Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -608,7 +602,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index b4905c6201..e7d5eaded7 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -293,13 +293,7 @@ class DeepseekV3Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 84df7b4d41..68aa54180c 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -397,23 +397,6 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -708,7 +691,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index f7bc2d2c5a..b772a9f04d 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -330,23 +330,6 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index cf7004c97f..c13eb25d9a 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1272,7 +1272,6 @@ class Emu3TextModel(Emu3PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 3a2e20e7cc..e508db3865 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -383,13 +383,7 @@ class FalconH1Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -1552,7 +1546,6 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 07b9e54084..540b7e7fee 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -251,13 +251,7 @@ class FalconH1Attention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -1307,7 +1301,6 @@ class FalconH1ForCausalLM(LlamaForCausalLM): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index b7990e9660..2a29608919 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -415,7 +415,6 @@ class GemmaModel(GemmaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) # embed positions diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 1a1e8cc1c6..e934df7ef8 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -416,7 +416,6 @@ class GemmaModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) # embed positions diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index fe5576ae1c..7bb865bc5d 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -218,13 +218,7 @@ class Gemma2Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -445,7 +439,6 @@ class Gemma2Model(Gemma2PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 7d0b721d80..31b251f4ca 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -283,13 +283,7 @@ class Gemma2Attention(GemmaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -428,7 +422,6 @@ class Gemma2Model(GemmaModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 122d16aafc..4a9fbfcd31 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -345,14 +345,7 @@ class Gemma3Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -566,7 +559,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { @@ -949,7 +941,6 @@ class Gemma3Model(Gemma3PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } if token_type_ids is not None and inputs_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` @@ -1200,7 +1191,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, token_type_ids: Optional[torch.Tensor] = None, **kwargs, ) -> dict: @@ -1211,7 +1201,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Add the token type ids mask for generate as well if token_type_ids is not None and input_embeds.shape[1] != 1: diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index f0761d863d..7f9d1be8d9 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -424,14 +424,7 @@ class Gemma3Attention(Gemma2Attention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -617,7 +610,6 @@ class Gemma3TextModel(Gemma2Model): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { @@ -840,7 +832,6 @@ class Gemma3Model(PaliGemmaModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } if token_type_ids is not None and inputs_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` @@ -1050,7 +1041,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, token_type_ids: Optional[torch.Tensor] = None, **kwargs, ) -> dict: @@ -1061,7 +1051,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Add the token type ids mask for generate as well if token_type_ids is not None and input_embeds.shape[1] != 1: diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 236b0ed5c4..235f8258c1 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -436,7 +436,6 @@ class GlmModel(GlmPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 48d8522502..f32bfb3a39 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -444,7 +444,6 @@ class Glm4Model(Glm4PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9c32acdb06..16de0f23db 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -166,28 +166,9 @@ class GPTNeoXAttention(nn.Module): } key_states, value_states = layer_past.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Checking for fallbacks in case an unsupported feature is requested - attention_type = self.config._attn_implementation - if (output_attentions or head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - ]: - logger.warning_once( - f"Setting `attention_type` to `eager` because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - attention_type = "eager" - - elif self.training and self.attention_dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Setting `attention_type` to `eager` because `dropout` is not supported in `{attention_type}`." - ) - attention_type = "eager" - attention_interface: Callable = eager_attention_forward - attention_interface = ( - ALL_ATTENTION_FUNCTIONS[attention_type] if attention_type != "eager" else attention_interface - ) + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] # Compute attention attn_output, attn_weights = attention_interface( @@ -409,7 +390,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) # Prepare head mask if needed diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 70bee31b28..e7d67a9764 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -153,28 +153,9 @@ class GPTNeoXAttention(nn.Module): } key_states, value_states = layer_past.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Checking for fallbacks in case an unsupported feature is requested - attention_type = self.config._attn_implementation - if (output_attentions or head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - ]: - logger.warning_once( - f"Setting `attention_type` to `eager` because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - attention_type = "eager" - - elif self.training and self.attention_dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Setting `attention_type` to `eager` because `dropout` is not supported in `{attention_type}`." - ) - attention_type = "eager" - attention_interface: Callable = eager_attention_forward - attention_interface = ( - ALL_ATTENTION_FUNCTIONS[attention_type] if attention_type != "eager" else attention_interface - ) + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] # Compute attention attn_output, attn_weights = attention_interface( @@ -356,7 +337,6 @@ class GPTNeoXModel(LlamaModel, nn.Module): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) # Prepare head mask if needed diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 639db77701..11f2873f3d 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -439,7 +439,6 @@ class GraniteModel(GranitePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 424a0cc3fa..33f3b3363e 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -181,7 +181,6 @@ class GraniteModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index fdd7addc45..a3a314a6ab 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -452,13 +452,7 @@ class GraniteMoeAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 48e1dc0020..d6ff36bf32 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -203,13 +203,7 @@ class GraniteMoeHybridAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 0845ba7b69..dc429aa55b 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -387,13 +387,7 @@ class GraniteMoeSharedAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index eea73341d0..b9cb3bafc1 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -421,7 +421,6 @@ class HeliumModel(HeliumPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 26463c2009..6f06b32c16 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -40,16 +40,12 @@ from ...utils import ( auto_docstring, can_return_tuple, is_torchdynamo_compiling, - logging, torch_int, ) from ..auto import AutoModel from .configuration_internvl import InternVLConfig, InternVLVisionConfig -logger = logging.get_logger(__name__) - - @use_kernel_forward_from_hub("RMSNorm") class InternVLVisionRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -151,13 +147,7 @@ class InternVLVisionAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index b4e1efe348..91e53ec523 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -108,13 +108,7 @@ class InternVLVisionAttention(JanusVisionAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index a5ba0337bb..959cdc6856 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -358,13 +358,7 @@ class JanusVisionAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 6ff52d7558..a696568778 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -506,13 +506,7 @@ class JanusVisionAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7b0416ec19..4502cee6e5 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -425,7 +425,6 @@ class LlamaModel(LlamaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 82caab17db..2f6202ef8d 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -356,13 +356,7 @@ class Llama4TextAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, @@ -570,7 +564,6 @@ class Llama4TextModel(Llama4PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { @@ -916,13 +909,7 @@ class Llama4VisionAttention(nn.Module): attention_interface: Callable = vision_eager_attention_forward # flex disable because breaks on TP 8, embed is 88 not power of 2 if self.config._attn_implementation not in ["eager", "flex_attention"]: - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 92ef09fc73..90881cbcd2 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -164,13 +164,7 @@ class MistralAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -404,7 +398,6 @@ class MistralModel(MistralPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 9dd2e051b5..e943150f01 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -75,13 +75,7 @@ class MistralAttention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -168,7 +162,6 @@ class MistralModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9f176a35d8..9147538f73 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -276,13 +276,7 @@ class MixtralAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -525,7 +519,6 @@ class MixtralModel(MixtralPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 95d8defddd..7132a4165f 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -370,7 +370,6 @@ class MixtralModel(MistralModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index d4f1441642..ad943fd9c5 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -28,13 +28,10 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, logging, torch_int +from ...utils import auto_docstring, can_return_tuple, torch_int from .configuration_mlcd import MLCDVisionConfig -logger = logging.get_logger(__name__) - - class MLCDMLP(nn.Module): def __init__(self, config): super().__init__() @@ -281,13 +278,7 @@ class MLCDAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index 7e55288b41..6186cbbbb4 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -226,13 +226,7 @@ class MLCDAttention(CLIPAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index edc70eafa0..a3aebaed9a 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -265,13 +265,7 @@ class MoonshineAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False @@ -749,7 +743,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index c22198843c..6abc22ae99 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -361,13 +361,7 @@ class MoonshineAttention(GlmAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False @@ -755,7 +749,6 @@ class MoonshineDecoder(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index fe4e081a3e..36999733b3 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -191,13 +191,7 @@ class OlmoAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -406,7 +400,6 @@ class OlmoModel(OlmoPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 02ff85ac10..18a17533a3 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -111,13 +111,7 @@ class OlmoAttention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index dd2fcdb17f..661a9341d6 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -179,13 +179,7 @@ class Olmo2Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -412,7 +406,6 @@ class Olmo2Model(Olmo2PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index a652709862..103d6616c5 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -225,13 +225,7 @@ class Olmo2Attention(OlmoAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4dade5e49e..3f2deffd9e 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -173,13 +173,7 @@ class PhiAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -400,7 +394,6 @@ class PhiModel(PhiPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 6f9edaba94..c1d40774bc 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -96,13 +96,7 @@ class PhiAttention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -251,7 +245,6 @@ class PhiModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 52cf0ef96d..08f93a468b 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -193,13 +193,7 @@ class Phi3Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -459,7 +453,6 @@ class Phi3Model(Phi3PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py index 4e38e3164d..9c34138880 100644 --- a/src/transformers/models/phi3/modular_phi3.py +++ b/src/transformers/models/phi3/modular_phi3.py @@ -145,13 +145,7 @@ class Phi3Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 6b6cef7df3..858666b9f5 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1427,13 +1427,7 @@ class Phi4MultimodalAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -1766,7 +1760,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 344e6c1776..e6ab2c1cb0 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -1577,7 +1577,6 @@ class Phi4MultimodalModel(Phi3Model, nn.Module): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 9e9b0641f0..03df9df94f 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -165,13 +165,7 @@ class Qwen2Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -411,7 +405,6 @@ class Qwen2Model(Qwen2PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 5a24b425f0..10f3c3acca 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -75,13 +75,7 @@ class Qwen2Attention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -173,7 +167,6 @@ class Qwen2Model(MistralModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 48eb9489be..ef1cd22e0c 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -212,13 +212,7 @@ class Qwen3Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -437,7 +431,6 @@ class Qwen3Model(Qwen3PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 096a0e5b9c..466eb3d029 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -89,13 +89,7 @@ class Qwen3Attention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index f349f2f3d6..77b2362fe1 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -177,13 +177,7 @@ class Qwen3MoeAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -530,7 +524,6 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 6d79e0f0f7..6e102d8014 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -184,13 +184,7 @@ class Starcoder2Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -405,7 +399,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index fd5840e40a..d415870e49 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -103,13 +103,7 @@ class Starcoder2Attention(MistralAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -220,7 +214,6 @@ class Starcoder2Model(MistralModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index b54f7c15b8..63ea960028 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -246,13 +246,7 @@ class TimesFmAttention(nn.Module): attention_interface: Callable = simple_eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 20ee49c01a..8918705f3a 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -202,13 +202,7 @@ class TimesFmAttention(nn.Module): attention_interface: Callable = simple_eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 8e5d1f6567..7733decd01 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -269,13 +269,7 @@ class ZambaAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 5f2c382c67..b2aed16823 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -53,7 +53,6 @@ if is_causal_conv1d_available(): else: causal_conv1d_update, causal_conv1d_fn = None, None - logger = logging.get_logger(__name__) @@ -435,13 +434,7 @@ class Zamba2Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index be4f38044c..c4e14dd148 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -270,13 +270,7 @@ class Zamba2Attention(ZambaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7bd494e452..87fab3f8af 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4353,7 +4353,8 @@ class ModelTesterMixin: if hasattr(config, "layer_types"): del config_dict["layer_types"] new_config = config.__class__(**config_dict) - model = model_class(new_config).to(torch_device) + # We need to set eager as otherwise `output_attentions` is not supported + model = model_class._from_config(new_config, attn_implementation="eager").to(torch_device) model.eval() layer_types = getattr(model.config, "layer_types", ["sliding_attention"] * config.num_hidden_layers) attentions = model(**inputs, output_attentions=True).attentions @@ -4370,7 +4371,8 @@ class ModelTesterMixin: if hasattr(config, "layer_types"): del config_dict["layer_types"] new_config = config.__class__(**config_dict) - model = model_class(new_config).to(torch_device) + # We need to set eager as otherwise `output_attentions` is not supported + model = model_class._from_config(new_config, attn_implementation="eager").to(torch_device) model.eval() attentions_not_sliding = model(**inputs, output_attentions=True).attentions for layer_attention in attentions_not_sliding: