diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index d449ca50a2..672adcf149 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -19,7 +19,6 @@ from ...utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_new_task_model import NewTaskModelConfig @@ -328,7 +327,6 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index abd7d1b8b0..7f88875f3e 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -41,7 +41,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torch_available from ..auto import AutoModel, AutoModelForCausalLM from .configuration_aria import AriaConfig, AriaTextConfig @@ -1160,7 +1159,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1417,7 +1415,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): return image_features @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) def forward( diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 574ee053a9..d331789354 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -49,7 +49,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torch_available from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer from ..llama.configuration_llama import LlamaConfig @@ -1452,7 +1451,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): return image_features @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) def forward( diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 8c1a3b23b9..13e3dfdb43 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -35,7 +35,6 @@ from ...utils import ( is_torchdynamo_compiling, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_aya_vision import AyaVisionConfig @@ -343,7 +342,6 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi image_features = self.multi_modal_projector(selected_image_feature) return image_features - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AyaVisionCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 5d70cf0dda..62496ba4eb 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -48,7 +48,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_bamba import BambaConfig @@ -1451,7 +1450,6 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 9c52fb16ac..8d64fdd989 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -55,7 +55,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_flash_attn_2_available, @@ -1188,7 +1187,6 @@ class BambaModel(BambaPreTrainedModel): class BambaForCausalLM(LlamaForCausalLM): @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 43c5294788..a250d47809 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -51,7 +51,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_cohere import CohereConfig @@ -799,7 +798,6 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 1d0702615e..a7189a1a21 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -42,7 +42,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_cohere2 import Cohere2Config @@ -797,7 +796,6 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 7da0810645..729e877cc8 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -35,7 +35,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_dbrx import DbrxConfig @@ -1273,7 +1272,6 @@ class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): def get_decoder(self) -> DbrxModel: return self.transformer - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(DBRX_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index ccddeed663..546c7fed3d 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -31,7 +31,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_deepseek_v3 import DeepseekV3Config @@ -942,7 +941,6 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index f5bed9bf87..1959fac86e 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -58,7 +58,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_diffllama import DiffLlamaConfig @@ -1045,7 +1044,6 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 3b0a4882de..f74b0cacb0 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -47,7 +47,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig @@ -1631,7 +1630,6 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") def forward( diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 62e95a7f73..c4e35e71d2 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -36,7 +36,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..chameleon.modeling_chameleon import ( ChameleonPreTrainedModel, ChameleonVQVAEEncoderConvDownsample, @@ -1085,7 +1084,6 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin): self.model = Emu3TextModel(config) @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") def forward(**super_kwargs): diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 82a4dfee08..86b974570d 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -45,7 +45,6 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_falcon import FalconConfig @@ -1151,7 +1150,6 @@ class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings: torch.Tensor): self.lm_head = new_embeddings - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index fedbac3cea..df66bc36e4 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -48,7 +48,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_gemma import GemmaConfig @@ -764,7 +763,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 22c4e599c2..f0d340048f 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -47,7 +47,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_gemma2 import Gemma2Config @@ -804,7 +803,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index e921579326..50ca08a3f1 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -45,7 +45,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig @@ -891,7 +890,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1201,7 +1199,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): return image_features @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index afcb72c202..f2e716f216 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -41,7 +41,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.modeling_gemma2 import ( Gemma2Attention, @@ -847,7 +846,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): return causal_mask @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 8b0ccd9c9e..a12057cbb2 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -50,7 +50,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_glm import GlmConfig @@ -780,7 +779,6 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 6356356f3e..8eb015ac00 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -50,7 +50,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_glm4 import Glm4Config @@ -421,12 +420,15 @@ GLM4_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -623,7 +625,7 @@ class Glm4Model(Glm4PreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -636,8 +638,7 @@ class Glm4Model(Glm4PreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -783,7 +784,6 @@ class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 1d77b8b7ba..1db20f8624 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -44,7 +44,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_granite import GraniteConfig @@ -780,7 +779,6 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GRANITE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 251beacfad..16fcdc3a77 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -50,7 +50,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_helium import HeliumConfig @@ -765,7 +764,6 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(HELIUM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index d23085bd37..186c7be6bb 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -34,7 +34,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig @@ -1293,7 +1292,6 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Idefics2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index c91cc6e299..64d939e7b4 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -47,7 +47,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_mamba_ssm_available, @@ -1434,7 +1433,6 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index bf59fd074e..bca39ba83e 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -42,7 +42,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_jetmoe import JetMoeConfig @@ -1259,7 +1258,6 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 6898f168a2..a8bebd2a36 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -50,7 +50,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_llama import LlamaConfig @@ -770,7 +769,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index a8b5f074d0..ba5277d4ff 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -32,7 +32,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava import LlavaConfig @@ -313,7 +312,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): image_features = self.multi_modal_projector(selected_image_feature) return image_features - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 6301402e6e..b3eae9c443 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -35,7 +35,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next import LlavaNextConfig @@ -525,7 +524,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi image_features = torch.split(image_features, image_num_patches, dim=0) return image_features - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index b4a9c899c9..1965f08bb4 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -39,7 +39,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next_video import LlavaNextVideoConfig @@ -565,7 +564,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene image_features = torch.split(image_features, image_num_patches, dim=0) return image_features - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 3f77d39c02..ec92838091 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -33,7 +33,6 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_onevision import LlavaOnevisionConfig @@ -586,7 +585,6 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene return video_features - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) def forward( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index bf2cccb65b..20e8aca622 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -36,7 +36,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_mistral import MistralConfig @@ -769,7 +768,6 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 08a8b7b315..c27a646dd6 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -36,7 +36,6 @@ from ...utils import ( is_torchdynamo_compiling, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_mistral3 import Mistral3Config @@ -364,7 +363,6 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) return image_features - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1f3b7bd0a6..377e18b875 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -59,7 +59,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_mixtral import MixtralConfig @@ -983,7 +982,6 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 77bff448ff..db961a30dc 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -36,7 +36,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig @@ -1873,7 +1872,6 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig") def forward( @@ -2018,7 +2016,6 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.language_model.get_decoder() - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig") def forward( diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 403c27f078..6e990c4402 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -48,7 +48,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto.modeling_auto import AutoModel from .configuration_moshi import MoshiConfig, MoshiDepthConfig @@ -1777,7 +1776,6 @@ class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoshiCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 4e816c2a5c..d33cb3a24b 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -47,7 +47,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_nemotron import NemotronConfig @@ -1011,7 +1010,6 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy (doc string different) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 7ac2dd6ad9..aa7d6e7445 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -29,7 +29,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_olmo import OlmoConfig @@ -740,7 +739,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index d09d47c7c7..999b2ded05 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -29,7 +29,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_olmo2 import Olmo2Config @@ -746,7 +745,6 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 007da568f0..429ea28413 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -37,7 +37,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_olmoe import OlmoeConfig @@ -1159,7 +1158,6 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index f13bd0feef..b7ac3b751b 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -33,7 +33,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_paligemma import PaliGemmaConfig @@ -405,7 +404,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 3b492ab851..564865ddef 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -47,7 +47,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_persimmon import PersimmonConfig @@ -817,7 +816,6 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index de1abd9963..8b8bd4b8e8 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -34,7 +34,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_phi import PhiConfig @@ -738,7 +737,6 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 2bf32cdcdc..4955857247 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -51,7 +51,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_phi3 import Phi3Config @@ -824,7 +823,6 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 42df5ac9d0..683aa431ce 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -42,7 +42,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_phimoe import PhimoeConfig @@ -1372,7 +1371,6 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHIMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8c95569a49..2d24e228c2 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -36,7 +36,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2 import Qwen2Config @@ -782,7 +781,6 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 4c71c431ac..718b719e58 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -43,6 +43,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -68,6 +69,12 @@ else: apply_rotary_emb = None +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -2035,7 +2042,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -2053,6 +2060,10 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -2747,7 +2758,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -2765,6 +2776,10 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 365b561eb9..d780d6051a 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -50,7 +50,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2_moe import Qwen2MoeConfig @@ -1230,7 +1229,6 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 3cc5de2421..8416478e5a 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -51,7 +51,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_qwen3 import Qwen3Config @@ -809,7 +808,6 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 6a062420f2..daf8335acd 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -54,7 +54,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_qwen3_moe import Qwen3MoeConfig @@ -997,7 +996,6 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py index 3981a1f54d..a0a65ea37c 100644 --- a/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py +++ b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py @@ -26,7 +26,6 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModelForImageTextToText from .configuration_shieldgemma2 import ShieldGemma2Config @@ -150,7 +149,6 @@ class ShieldGemma2ForImageClassification(PreTrainedModel): def tie_weights(self): return self.model.language_model.tie_weights() - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(SHIELDGEMMA2_INPUTS_DOCSTRING) def forward( self, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 685a7c2372..85e20b4493 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -48,7 +48,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_stablelm import StableLmConfig @@ -1072,7 +1071,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index ae16daaa46..0299cab66f 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -53,7 +53,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from .configuration_starcoder2 import Starcoder2Config @@ -759,7 +758,6 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 24aaee0351..f092bba196 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -32,7 +32,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_video_llava import VideoLlavaConfig @@ -359,7 +358,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi return video_features, num_frames - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 71e5b9498b..c6060b756e 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -32,7 +32,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_vipllava import VipLlavaConfig @@ -288,7 +287,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) image_features = self.multi_modal_projector(image_features) return image_features - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 2fe1720c42..29c54fbdf4 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -48,7 +48,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_mamba_ssm_available, @@ -1208,7 +1207,6 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ZAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index a3735303ec..7854b04bc6 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -43,7 +43,6 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_zamba2 import Zamba2Config @@ -1617,7 +1616,6 @@ class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward(