From d3af76df58476830eb5b5981decc64af15e369f5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 23 Jan 2025 09:47:54 +0100 Subject: [PATCH] [Backend support] Allow `num_logits_to_keep` as Tensor + add flag (#35757) * support * Update modeling_utils.py * style * most models * Other models * fix-copies * tests + generation utils --- .../generation/candidate_generator.py | 6 ++-- src/transformers/generation/utils.py | 24 +++++++-------- src/transformers/modeling_utils.py | 9 ++++++ src/transformers/models/aria/modeling_aria.py | 28 ++++++++++------- src/transformers/models/aria/modular_aria.py | 17 +++++++---- .../models/bamba/modeling_bamba.py | 19 +++++++----- .../models/bamba/modular_bamba.py | 18 ++++++----- .../models/cohere/modeling_cohere.py | 14 ++++++--- .../models/cohere/modular_cohere.py | 11 ++++--- .../models/cohere2/modeling_cohere2.py | 20 ++++++++----- .../models/cohere2/modular_cohere2.py | 6 ++-- src/transformers/models/dbrx/modeling_dbrx.py | 13 +++++--- .../models/diffllama/modeling_diffllama.py | 14 ++++++--- .../models/diffllama/modular_diffllama.py | 1 + src/transformers/models/emu3/modeling_emu3.py | 25 +++++++++++----- src/transformers/models/emu3/modular_emu3.py | 20 +++++++++---- .../models/falcon/modeling_falcon.py | 13 +++++--- .../models/gemma/modeling_gemma.py | 14 ++++++--- .../models/gemma/modular_gemma.py | 6 ++-- .../models/gemma2/modeling_gemma2.py | 20 ++++++++----- .../models/gemma2/modular_gemma2.py | 11 +++---- src/transformers/models/glm/modeling_glm.py | 14 ++++++--- .../models/granite/modeling_granite.py | 14 ++++++--- .../models/granite/modular_granite.py | 5 ++-- .../models/helium/modeling_helium.py | 14 ++++++--- .../models/idefics2/modeling_idefics2.py | 19 +++++++----- .../models/idefics3/modeling_idefics3.py | 6 ++-- .../models/jamba/modeling_jamba.py | 23 +++++++------- .../models/jetmoe/modeling_jetmoe.py | 13 +++++--- .../models/llama/modeling_llama.py | 14 ++++++--- .../models/llava/modeling_llava.py | 16 ++++++---- .../models/llava_next/modeling_llava_next.py | 16 ++++++---- .../modeling_llava_next_video.py | 16 ++++++---- .../modular_llava_next_video.py | 14 +++++---- .../modeling_llava_onevision.py | 16 ++++++---- .../models/mistral/modeling_mistral.py | 14 ++++++--- .../models/mixtral/modeling_mixtral.py | 14 ++++++--- .../models/mixtral/modular_mixtral.py | 11 ++++--- .../models/mllama/modeling_mllama.py | 30 ++++++++++++------- .../models/moshi/modeling_moshi.py | 17 +++++++---- .../models/nemotron/modeling_nemotron.py | 13 +++++--- src/transformers/models/olmo/modeling_olmo.py | 14 ++++++--- .../models/olmo2/modeling_olmo2.py | 14 ++++++--- .../models/olmoe/modeling_olmoe.py | 15 ++++++---- .../models/paligemma/modeling_paligemma.py | 16 ++++++---- .../models/persimmon/modeling_persimmon.py | 13 +++++--- src/transformers/models/phi/modeling_phi.py | 14 ++++++--- src/transformers/models/phi3/modeling_phi3.py | 18 +++++++---- src/transformers/models/phi3/modular_phi3.py | 4 +-- .../models/phimoe/modeling_phimoe.py | 19 +++++++----- .../models/qwen2/modeling_qwen2.py | 14 ++++++--- .../models/qwen2_moe/modeling_qwen2_moe.py | 13 +++++--- .../models/stablelm/modeling_stablelm.py | 13 +++++--- .../models/starcoder2/modeling_starcoder2.py | 14 ++++++--- .../video_llava/modeling_video_llava.py | 16 ++++++---- .../models/vipllava/modeling_vipllava.py | 16 ++++++---- .../models/zamba/modeling_zamba.py | 19 +++++++----- src/transformers/utils/deprecation.py | 11 +++++-- tests/generation/test_utils.py | 24 +++++++-------- tests/models/bamba/test_modeling_bamba.py | 2 +- tests/test_modeling_common.py | 14 ++++----- tests/utils/test_deprecation.py | 27 ++++++++++++++++- 62 files changed, 603 insertions(+), 315 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index ba5d0f0005..9689ca2b52 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -129,9 +129,9 @@ class AssistedCandidateGenerator(CandidateGenerator): value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value) ) - # Remove potential default "num_logits_to_keep" key - if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep(): - del assistant_kwargs["num_logits_to_keep"] + # Remove potential default "logits_to_keep" key + if "logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_logits_to_keep(): + del assistant_kwargs["logits_to_keep"] if "assistant_encoder_outputs" in model_kwargs: assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 461d7e1215..94230d1b72 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1780,12 +1780,12 @@ class GenerationMixin: else EncoderDecoderCache(DynamicCache(), DynamicCache()) ) - def _supports_num_logits_to_keep(self) -> bool: + def _supports_logits_to_keep(self) -> bool: """ - Return True if the current model supports the keyword argument `num_logits_to_keep` in forward() + Return True if the current model supports the keyword argument `logits_to_keep` in forward() to save memory. Checking it in this way allows to avoid using a new model attribute. """ - return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) + return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) def _prepare_special_tokens( self, @@ -2066,11 +2066,11 @@ class GenerationMixin: input_ids_length=input_ids_length, ) - # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole + # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding # dynamically overrides this value as it can need more than the last token logits - if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: - model_kwargs["num_logits_to_keep"] = 1 + if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs: + model_kwargs["logits_to_keep"] = 1 self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) @@ -4236,8 +4236,8 @@ class GenerationMixin: ) model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) - if "num_logits_to_keep" in model_inputs: - model_inputs["num_logits_to_keep"] = candidate_length + 1 + if "logits_to_keep" in model_inputs: + model_inputs["logits_to_keep"] = candidate_length + 1 # 2.2. Run a forward pass on the candidate sequence # prepare variable output controls (note: some models won't accept all output controls) @@ -4575,7 +4575,7 @@ def _split_model_inputs( # ModelOutput object. # bool should not be split but replicated for each split bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"] + keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"] non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] num_hidden_layers = config.get_text_config().num_hidden_layers @@ -4595,10 +4595,10 @@ def _split_model_inputs( data_split_list = [ {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) ] - # num_logits_to_keep should be replicated for each split, similar to bool values - if "num_logits_to_keep" in model_input: + # logits_to_keep should be replicated for each split, similar to bool values + if "logits_to_keep" in model_input: data_split_list = [ - {**data_split, "num_logits_to_keep": model_input["num_logits_to_keep"]} for data_split in data_split_list + {**data_split, "logits_to_keep": model_input["logits_to_keep"]} for data_split in data_split_list ] # Convert each dictionary in the list to an object of the inferred class diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a1e2db6c08..d869229af9 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1292,6 +1292,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # `config.base_model_tp_plan` during `post_init`. _tp_plan = None + # This flag signal that the model can be used as an efficient backend in TGI and vLLM + # In practice, it means that they support attention interface functions, fully pass the kwargs + # through all modules up to the Attention layer, and can slice logits with Tensor + _supports_attention_backend = False + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ @@ -5188,6 +5193,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict()) return self._compiled_call + @classmethod + def is_backend_compatible(cls): + return cls._supports_attention_backend + PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) if PreTrainedModel.push_to_hub.__doc__ is not None: diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 0b330b4aee..4143016735 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -37,6 +37,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torch_available from ..auto import AutoModel, AutoModelForCausalLM from .configuration_aria import AriaConfig, AriaTextConfig @@ -708,6 +709,7 @@ class AriaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = False def _init_weights(self, module): std = self.config.initializer_range @@ -1168,6 +1170,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1183,7 +1186,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1193,10 +1196,12 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1239,7 +1244,8 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1324,8 +1330,9 @@ ARIA_INPUTS_DOCSTRING = r""" Whether to output hidden states. return_dict (`bool`, *optional*): Whether to return a `ModelOutput` object. - num_logits_to_keep (`int`, *optional*, defaults to 0): - Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`. + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): + If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`. + Otherwise, slice according to the 1D tensor in the sequence length dimension cache_position (`torch.LongTensor`, *optional*): Cache positions. **loss_kwargs: @@ -1426,6 +1433,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) def forward( @@ -1442,7 +1450,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, **loss_kwargs, ) -> Union[Tuple, AriaCausalLMOutputWithPast]: @@ -1552,7 +1560,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -1584,7 +1592,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): pixel_mask=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): model_inputs = self.language_model.prepare_inputs_for_generation( @@ -1593,7 +1601,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 295e2dcb74..5c40473a18 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -45,6 +45,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torch_available from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer from ..llama.configuration_llama import LlamaConfig @@ -1222,6 +1223,8 @@ class AriaTextPreTrainedModel(PreTrainedModel): class AriaPreTrainedModel(LlamaPreTrainedModel): + _supports_attention_backend = False + def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): @@ -1301,8 +1304,9 @@ ARIA_INPUTS_DOCSTRING = r""" Whether to output hidden states. return_dict (`bool`, *optional*): Whether to return a `ModelOutput` object. - num_logits_to_keep (`int`, *optional*, defaults to 0): - Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`. + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): + If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`. + Otherwise, slice according to the 1D tensor in the sequence length dimension cache_position (`torch.LongTensor`, *optional*): Cache positions. **loss_kwargs: @@ -1403,6 +1407,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) def forward( @@ -1419,7 +1424,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, **loss_kwargs, ) -> Union[Tuple, AriaCausalLMOutputWithPast]: @@ -1529,7 +1534,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -1561,7 +1566,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): pixel_mask=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): model_inputs = self.language_model.prepare_inputs_for_generation( @@ -1570,7 +1575,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 20a3247be2..edfc162a03 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -41,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_mamba_2_ssm_available, @@ -1466,6 +1467,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1481,7 +1483,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1491,10 +1493,12 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1537,7 +1541,8 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1602,7 +1607,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": self.config.num_logits_to_keep, + "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 7fb35f48fb..93fb274e4d 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -54,6 +54,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_flash_attn_2_available, @@ -1182,6 +1183,7 @@ class BambaModel(BambaPreTrainedModel): class BambaForCausalLM(LlamaForCausalLM): + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1197,7 +1199,7 @@ class BambaForCausalLM(LlamaForCausalLM): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1207,10 +1209,12 @@ class BambaForCausalLM(LlamaForCausalLM): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1242,7 +1246,7 @@ class BambaForCausalLM(LlamaForCausalLM): output_hidden_states, return_dict, cache_position, - num_logits_to_keep, + logits_to_keep, **kwargs, ) @@ -1293,7 +1297,7 @@ class BambaForCausalLM(LlamaForCausalLM): "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": self.config.num_logits_to_keep, + "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 9c7207adfc..7337ae6acf 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -48,6 +48,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_cohere import CohereConfig @@ -421,6 +422,7 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -808,6 +810,7 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -823,7 +826,7 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -833,10 +836,12 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -879,7 +884,8 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits * self.logit_scale # main diff from Llama loss = None diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index 6ea8fd6c83..17eb3f6a34 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -317,7 +317,7 @@ class CohereForCausalLM(LlamaForCausalLM): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -327,10 +327,12 @@ class CohereForCausalLM(LlamaForCausalLM): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -373,7 +375,8 @@ class CohereForCausalLM(LlamaForCausalLM): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits * self.logit_scale # main diff from Llama loss = None diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 15469577fb..2353601d91 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -39,6 +39,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_cohere2 import Cohere2Config @@ -421,6 +422,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -779,6 +781,7 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -794,7 +797,7 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -804,10 +807,12 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -850,7 +855,8 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits * self.logit_scale # main diff from Llama loss = None @@ -878,7 +884,7 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten: has a special cache type, `HybridCache` @@ -933,8 +939,8 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): batch_size=batch_size, ) - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 7020df2702..1e3295fce3 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -544,7 +544,7 @@ class Cohere2ForCausalLM(CohereForCausalLM): cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten: has a special cache type, `HybridCache` @@ -599,8 +599,8 @@ class Cohere2ForCausalLM(CohereForCausalLM): batch_size=batch_size, ) - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index a2373d3454..5ad827689b 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -35,6 +35,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_dbrx import DbrxConfig @@ -1257,6 +1258,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): def get_decoder(self) -> DbrxModel: return self.transformer + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(DBRX_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1273,7 +1275,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r"""Forward function for causal language modeling. @@ -1283,10 +1285,12 @@ class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1333,7 +1337,8 @@ class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # No upscaling to float was ever done for Dbrx - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index bba4e64659..c262340aac 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -51,6 +51,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_diffllama import DiffLlamaConfig @@ -599,6 +600,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = False def _init_weights(self, module): std = self.config.initializer_range @@ -1045,6 +1047,7 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1060,7 +1063,7 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1070,10 +1073,12 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1116,7 +1121,8 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 2c8c846706..c6bdf18093 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -432,6 +432,7 @@ class DiffLlamaDecoderLayer(LlamaDecoderLayer): class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): _supports_flex_attn = False + _supports_attention_backend = False class DiffLlamaModel(LlamaModel): diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index b42a222f6c..6944f91b97 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -44,6 +44,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig @@ -1626,6 +1627,7 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") def forward( @@ -1641,7 +1643,7 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1650,10 +1652,13 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1696,7 +1701,8 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1865,7 +1871,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1873,10 +1879,13 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1949,7 +1958,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) return outputs diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index aacf52fe31..01d09b703d 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -36,6 +36,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..chameleon.modeling_chameleon import ( ChameleonPreTrainedModel, ChameleonVQVAEEncoderConvDownsample, @@ -1071,6 +1072,7 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin): super().__init__(config) self.model = Emu3TextModel(config) + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") def forward(**super_kwargs): @@ -1080,10 +1082,13 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1177,7 +1182,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1185,10 +1190,13 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1261,7 +1269,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) return outputs diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index c0fad1ab66..f499801d21 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -46,6 +46,7 @@ from ...utils import ( is_flash_attn_greater_or_equal_2_10, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_falcon import FalconConfig @@ -1176,6 +1177,7 @@ class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings: torch.Tensor): self.lm_head = new_embeddings + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1196,7 +1198,7 @@ class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1204,10 +1206,12 @@ class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin): `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1227,7 +1231,8 @@ class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin): ) hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + lm_logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 66e975edaa..caaf2c60f5 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -46,6 +46,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_gemma import GemmaConfig @@ -387,6 +388,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -777,6 +779,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -792,7 +795,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -802,10 +805,12 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -848,7 +853,8 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 29b6f8a194..9c015d37c2 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -474,10 +474,12 @@ class GemmaForCausalLM(LlamaForCausalLM): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index fb7e59051a..9bc0d27816 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -44,6 +44,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_gemma2 import Gemma2Config @@ -417,6 +418,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -781,6 +783,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -796,7 +799,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -806,10 +809,12 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -856,7 +861,8 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) if self.config.final_logit_softcapping is not None: logits = logits / self.config.final_logit_softcapping logits = torch.tanh(logits) @@ -887,7 +893,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten: has a special cache type, `HybridCache` @@ -942,8 +948,8 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): batch_size=batch_size, ) - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 53a947eb95..ffd75fa2f0 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -539,7 +539,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -584,7 +584,8 @@ class Gemma2ForCausalLM(GemmaForCausalLM): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) if self.config.final_logit_softcapping is not None: logits = logits / self.config.final_logit_softcapping logits = torch.tanh(logits) @@ -615,7 +616,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM): cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten: has a special cache type, `HybridCache` @@ -670,8 +671,8 @@ class Gemma2ForCausalLM(GemmaForCausalLM): batch_size=batch_size, ) - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 3e5107c561..a3461ffd71 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -46,6 +46,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_glm import GlmConfig @@ -402,6 +403,7 @@ class GlmPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -787,6 +789,7 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -802,7 +805,7 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -812,10 +815,12 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -858,7 +863,8 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 3c887d3a1b..4549cdd5d7 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -40,6 +40,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_granite import GraniteConfig @@ -402,6 +403,7 @@ class GranitePreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -790,6 +792,7 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GRANITE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -805,7 +808,7 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -815,10 +818,12 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -861,7 +866,8 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits / self.config.logits_scaling # main diff with Llama loss = None diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 698280085f..f23ae4a673 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -245,7 +245,7 @@ class GraniteForCausalLM(LlamaForCausalLM): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -271,7 +271,8 @@ class GraniteForCausalLM(LlamaForCausalLM): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits / self.config.logits_scaling # main diff with Llama loss = None diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 7eed89b4af..71518c4a9a 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -47,6 +47,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_helium import HeliumConfig @@ -389,6 +390,7 @@ class HeliumPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -774,6 +776,7 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(HELIUM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -789,7 +792,7 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -799,10 +802,12 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -845,7 +850,8 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 4e819811a9..3aaf46d63d 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -37,6 +37,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig @@ -1508,6 +1509,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Idefics2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1525,7 +1527,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]: r""" Args: @@ -1535,10 +1537,12 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1604,7 +1608,8 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1648,7 +1653,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) pixel_values=None, pixel_attention_mask=None, image_hidden_states=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take @@ -1677,8 +1682,8 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep if image_hidden_states is not None: pixel_values = None diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 31cf1a2e8f..e4cc8bda56 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -1242,7 +1242,7 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin) pixel_values=None, pixel_attention_mask=None, image_hidden_states=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take @@ -1271,8 +1271,8 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin) # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep if image_hidden_states is not None: pixel_values = None diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index fd6b1bae31..24aeb9890b 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -45,6 +45,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_flash_attn_2_available, @@ -1433,9 +1434,9 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy def forward( self, input_ids: torch.LongTensor = None, @@ -1450,7 +1451,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[Union[int, None]] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -1460,10 +1461,12 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1510,10 +1513,8 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): ) hidden_states = outputs[0] - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states) - else: - logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1595,7 +1596,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, - "num_logits_to_keep": self.config.num_logits_to_keep, + "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 433ca61fab..fca47eb3fa 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -42,6 +42,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_jetmoe import JetMoeConfig @@ -1274,6 +1275,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1290,7 +1292,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1299,10 +1301,12 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: """ @@ -1329,7 +1333,8 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8cbb12628c..361ae15c31 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -47,6 +47,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_llama import LlamaConfig @@ -391,6 +392,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -776,6 +778,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -791,7 +794,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -801,10 +804,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -847,7 +852,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 93d7465291..fcf016f28f 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -31,6 +31,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava import LlavaConfig @@ -380,6 +381,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): return final_embedding, final_attention_mask, final_labels, position_ids + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -398,7 +400,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: r""" Args: @@ -407,10 +409,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -490,7 +494,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -534,7 +538,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): pixel_values=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -545,7 +549,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 51df47233b..8bff9dc900 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -34,6 +34,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next import LlavaNextConfig @@ -752,6 +753,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi image_features = torch.split(image_features, image_num_patches, dim=0) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -771,7 +773,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]: r""" Args: @@ -780,10 +782,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -871,7 +875,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -916,7 +920,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi image_sizes=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -927,7 +931,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 257c81aa8f..c82d52bfda 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -33,6 +33,7 @@ from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next_video import LlavaNextVideoConfig @@ -787,6 +788,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene image_features = torch.split(image_features, image_num_patches, dim=0) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -807,7 +809,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" Args: @@ -819,10 +821,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -967,7 +971,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -1014,7 +1018,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene image_sizes=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- extra custom processing @@ -1025,7 +1029,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 89975a745b..580f890b42 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -335,7 +335,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" Args: @@ -347,10 +347,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -495,7 +497,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -542,7 +544,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): image_sizes=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- extra custom processing @@ -553,7 +555,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 5c5471479e..f1cf7a6c2d 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -32,6 +32,7 @@ from ...utils import ( add_start_docstrings, logging, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_onevision import LlavaOnevisionConfig @@ -568,6 +569,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene return video_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) def forward( self, @@ -589,7 +591,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]: r""" Args: @@ -598,10 +600,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -734,7 +738,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -782,7 +786,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene image_sizes_videos=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -793,7 +797,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 635cda9cc8..cc62d378eb 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -32,6 +32,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_mistral import MistralConfig @@ -363,6 +364,7 @@ class MistralPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -777,6 +779,7 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -792,7 +795,7 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -802,10 +805,12 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -848,7 +853,8 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 8cf2d0e8fa..034ddba8c4 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -55,6 +55,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_mixtral import MixtralConfig @@ -485,6 +486,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -996,6 +998,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1012,7 +1015,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1022,10 +1025,12 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1074,7 +1079,8 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index a6069f69b3..a16e4c5a16 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -466,7 +466,7 @@ class MixtralForCausalLM(MistralForCausalLM): output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -476,10 +476,12 @@ class MixtralForCausalLM(MistralForCausalLM): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -528,7 +530,8 @@ class MixtralForCausalLM(MistralForCausalLM): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index b40c366a6d..d1f83e13d8 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -35,6 +35,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig @@ -1872,6 +1873,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig") def forward( @@ -1890,7 +1892,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1900,10 +1902,12 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1950,7 +1954,8 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]).float() loss = None if labels is not None: @@ -2014,6 +2019,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.language_model.get_decoder() + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig") def forward( @@ -2034,7 +2040,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -2043,10 +2049,12 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -2140,7 +2148,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): output_attentions=output_attentions, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) return outputs @@ -2158,7 +2166,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): past_key_values=None, use_cache=False, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -2190,8 +2198,8 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index a1c15b7a0b..3796e2dc5f 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -48,6 +48,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto.modeling_auto import AutoModel from .configuration_moshi import MoshiConfig, MoshiDepthConfig @@ -1788,6 +1789,7 @@ class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoshiCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1803,7 +1805,7 @@ class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, MoshiCausalLMOutputWithPast]: r""" Args: @@ -1812,10 +1814,12 @@ class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1861,7 +1865,8 @@ class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin): "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -2446,7 +2451,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, user_delay_pattern_mask=None, moshi_delay_pattern_mask=None, kwargs_depth_decoder=None, @@ -2463,7 +2468,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 54f774f0b9..8ae6e9c77f 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -46,6 +46,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_nemotron import NemotronConfig @@ -1023,6 +1024,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy (doc string different) @@ -1039,7 +1041,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1049,10 +1051,12 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1094,7 +1098,8 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 34e9f7259c..c2e1ae15b4 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -26,6 +26,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_olmo import OlmoConfig @@ -367,6 +368,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -752,6 +754,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -767,7 +770,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -777,10 +780,12 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -823,7 +828,8 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index a6a1926501..163956d61a 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -25,6 +25,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_olmo2 import Olmo2Config @@ -368,6 +369,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -753,6 +755,7 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -768,7 +771,7 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -778,10 +781,12 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -824,7 +829,8 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 5c78138c1a..47126da956 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -38,6 +38,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_olmoe import OlmoeConfig @@ -756,7 +757,6 @@ OLMOE_START_DOCSTRING = r""" "The bare Olmoe Model outputting raw hidden-states without any specific head on top.", OLMOE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmoe class OlmoePreTrainedModel(PreTrainedModel): config_class = OlmoeConfig base_model_prefix = "model" @@ -765,7 +765,6 @@ class OlmoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -1186,6 +1185,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1202,7 +1202,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -1212,10 +1212,12 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1262,7 +1264,8 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 36a9e59118..5889f92f3c 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -32,6 +32,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_paligemma import PaliGemmaConfig @@ -412,6 +413,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -429,7 +431,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: r""" Args: @@ -438,10 +440,12 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -532,7 +536,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs.logits @@ -581,7 +585,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi attention_mask=None, token_type_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, labels=None, **kwargs, ): @@ -594,7 +598,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi position_ids=position_ids, cache_position=cache_position, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, token_type_ids=token_type_ids, **kwargs, ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8336ab5a2c..d1cb495294 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -46,6 +46,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_persimmon import PersimmonConfig @@ -830,6 +831,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -845,7 +847,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -854,10 +856,12 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -900,7 +904,8 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # No upscaling to float was ever done for Persimmon - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 33439dff75..7d360b1ed4 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -31,6 +31,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_phi import PhiConfig @@ -363,6 +364,7 @@ class PhiPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -750,6 +752,7 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -765,7 +768,7 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -775,10 +778,12 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -821,7 +826,8 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index cf905cb62e..e86e028b40 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -47,6 +47,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_phi3 import Phi3Config @@ -432,6 +433,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True _version = "0.0.5" def _init_weights(self, module): @@ -847,6 +849,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -862,7 +865,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -872,10 +875,12 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -918,7 +923,8 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -945,7 +951,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the @@ -970,7 +976,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) return model_inputs diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py index 2b1a19be4a..27f7c42f5b 100644 --- a/src/transformers/models/phi3/modular_phi3.py +++ b/src/transformers/models/phi3/modular_phi3.py @@ -275,7 +275,7 @@ class Phi3ForCausalLM(MistralForCausalLM, Phi3PreTrainedModel): cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the @@ -300,7 +300,7 @@ class Phi3ForCausalLM(MistralForCausalLM, Phi3PreTrainedModel): cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) return model_inputs diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index b540dd1830..ba4b766507 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -40,6 +40,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torch_fx_available from .configuration_phimoe import PhimoeConfig @@ -901,7 +902,6 @@ PHIMOE_START_DOCSTRING = r""" "The bare Phimoe Model outputting raw hidden-states without any specific head on top.", PHIMOE_START_DOCSTRING, ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralPreTrainedModel with Mixtral->Phimoe class PhimoePreTrainedModel(PreTrainedModel): config_class = PhimoeConfig base_model_prefix = "model" @@ -910,7 +910,6 @@ class PhimoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -1365,6 +1364,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHIMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy @@ -1382,7 +1382,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -1392,10 +1392,12 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: Example: ```python @@ -1445,7 +1447,8 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1488,7 +1491,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the @@ -1513,7 +1516,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) return model_inputs diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f8be4e3740..96cd6a6aa3 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -32,6 +32,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2 import Qwen2Config @@ -376,6 +377,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -761,6 +763,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -776,7 +779,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -786,10 +789,12 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -832,7 +837,8 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 0f61323f40..ad61003c86 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -49,6 +49,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2_moe import Qwen2MoeConfig @@ -1247,6 +1248,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1263,7 +1265,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -1273,10 +1275,12 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1323,7 +1327,8 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 4cdab6dc4d..55a85a9a1f 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -48,6 +48,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_stablelm import StableLmConfig @@ -1086,6 +1087,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy @@ -1102,7 +1104,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1111,10 +1113,12 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1156,7 +1160,8 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # No upscaling to float was ever done for StableLm - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 500f96e3e3..57898bc8d6 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -51,6 +51,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_starcoder2 import Starcoder2Config @@ -368,6 +369,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -773,6 +775,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -788,7 +791,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -798,10 +801,12 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -844,7 +849,8 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 293fb10ae2..f592da8185 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -31,6 +31,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_video_llava import VideoLlavaConfig @@ -409,6 +410,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi return video_features, num_frames + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -428,7 +430,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]: r""" Args: @@ -437,10 +439,12 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -579,7 +583,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -625,7 +629,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi pixel_values_videos=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -636,7 +640,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 0daaa8327b..8ef881b771 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -31,6 +31,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_vipllava import VipLlavaConfig @@ -373,6 +374,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) return final_embedding, final_attention_mask, final_labels, position_ids + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy @@ -391,7 +393,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, VipLlavaCausalLMOutputWithPast]: r""" Args: @@ -400,10 +402,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -479,7 +483,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -521,7 +525,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) pixel_values=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -532,7 +536,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 761c799bdc..a25cfbc428 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -48,6 +48,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_mamba_ssm_available, @@ -1217,6 +1218,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ZAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1232,7 +1234,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1242,10 +1244,12 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1289,7 +1293,8 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1355,7 +1360,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": self.config.num_logits_to_keep, + "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) diff --git a/src/transformers/utils/deprecation.py b/src/transformers/utils/deprecation.py index e8416c9f11..064decb14d 100644 --- a/src/transformers/utils/deprecation.py +++ b/src/transformers/utils/deprecation.py @@ -19,7 +19,12 @@ from typing import Optional import packaging.version from .. import __version__ -from . import ExplicitEnum +from . import ExplicitEnum, is_torch_available, is_torchdynamo_compiling + + +# This is needed in case we deprecate a kwarg of a function/method being compiled +if is_torch_available(): + import torch # noqa: F401 class Action(ExplicitEnum): @@ -40,6 +45,7 @@ def deprecate_kwarg( ): """ Function or method decorator to notify users about deprecated keyword arguments, replacing them with a new name if specified. + Note that is decorator is `torch.compile`-safe, i.e. it will not cause graph breaks (but no warning will be displayed if compiling). This decorator allows you to: - Notify users when a keyword argument is deprecated. @@ -158,7 +164,8 @@ def deprecate_kwarg( # raise error or notify user if minimum_action == Action.RAISE: raise ValueError(message) - elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS): + # If we are compiling, we do not raise the warning as it would break compilation + elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling(): # DeprecationWarning is ignored by default, so we use FutureWarning instead warnings.warn(message, FutureWarning, stacklevel=2) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ba61d4b436..b47566354b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2029,10 +2029,10 @@ class GenerationTesterMixin: self._check_similar_generate_outputs(dynamic_result, compiled_result) @pytest.mark.generate - def test_generate_methods_with_num_logits_to_keep(self): + def test_generate_methods_with_logits_to_keep(self): for model_class in self.all_generative_model_classes: - if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): - self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `logits_to_keep` argument.") config, inputs_dict = self.prepare_config_and_inputs_for_generate() config.use_cache = True @@ -2047,17 +2047,17 @@ class GenerationTesterMixin: "do_sample": False, } - # Setting num_logits_to_keep at 0 keeps all logits (old behavior) - with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0) - # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + # Setting logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0) + # By default, logits_to_keep is automatically set to 1 if not provided (new behavior) without_all_logits = model.generate(**inputs_dict, **generation_kwargs) self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) @pytest.mark.generate - def test_assisted_decoding_with_num_logits_to_keep(self): + def test_assisted_decoding_with_logits_to_keep(self): for model_class in self.all_generative_model_classes: - if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): - self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `logits_to_keep` argument.") if model_class._is_stateful: self.skipTest(reason="Stateful models don't support assisted generation") @@ -2081,9 +2081,9 @@ class GenerationTesterMixin: "output_scores": True, } - # Setting num_logits_to_keep at 0 keeps all logits (old behavior) - with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0) - # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + # Setting logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0) + # By default, logits_to_keep is automatically set to 1 if not provided (new behavior) without_all_logits = model.generate(**inputs_dict, **generation_kwargs) self._check_similar_generate_outputs(with_all_logits, without_all_logits) diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index 9356824dab..16be88f949 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -531,7 +531,7 @@ class BambaModelIntegrationTest(unittest.TestCase): # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist if self.cuda_compute_capability_major_version == 8: with torch.no_grad(): - logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits + logits = self.model(input_ids=input_ids, logits_to_keep=40).logits EXPECTED_LOGITS_NO_GRAD = torch.tensor( [ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bde6e07eff..8a14ba6669 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4759,21 +4759,21 @@ class ModelTesterMixin: for name, param in model._orig_mod.named_parameters(): torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4) - def test_forward_with_num_logits_to_keep(self): + def test_forward_with_logits_to_keep(self): for model_class in self.all_generative_model_classes: - if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): - self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `logits_to_keep` argument.") config, inputs = self.model_tester.prepare_config_and_inputs_for_common() batch_size, sequence_length = inputs["input_ids"].shape vocab_size = config.get_text_config().vocab_size model = model_class(config).to(device=torch_device).eval() - # some models have labels but `num_logits_to_keep` should not be used in train mode + # some models have labels but `logits_to_keep` should not be used in train mode _ = inputs.pop("labels", None) - # num_logits_to_keep=0 is a special case meaning "keep all logits" - all_logits = model(**inputs, num_logits_to_keep=0).logits - last_token_logits = model(**inputs, num_logits_to_keep=1).logits + # logits_to_keep=0 is a special case meaning "keep all logits" + all_logits = model(**inputs, logits_to_keep=0).logits + last_token_logits = model(**inputs, logits_to_keep=1).logits # Assert all shapes are correct self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size)) diff --git a/tests/utils/test_deprecation.py b/tests/utils/test_deprecation.py index e8e7e671ad..bf9f63e070 100644 --- a/tests/utils/test_deprecation.py +++ b/tests/utils/test_deprecation.py @@ -17,10 +17,15 @@ import warnings from parameterized import parameterized -from transformers import __version__ +from transformers import __version__, is_torch_available +from transformers.testing_utils import require_torch_gpu from transformers.utils.deprecation import deprecate_kwarg +if is_torch_available(): + import torch + + INFINITE_VERSION = "9999.0.0" @@ -168,3 +173,23 @@ class DeprecationDecoratorTester(unittest.TestCase): with self.assertWarns(FutureWarning): result = dummy_function(deprecated_name="old_value", new_name="new_value") self.assertEqual(result, "new_value") + + @require_torch_gpu + def test_compile_safe(self): + @deprecate_kwarg("deprecated_factor", new_name="new_factor", version=INFINITE_VERSION) + def dummy_function(new_factor=None, **kwargs): + return new_factor * torch.ones(1, device="cuda") + + compiled_function = torch.compile(dummy_function, fullgraph=True) + + # Check that we can correctly call the compiled function with the old name, without raising errors + out = compiled_function(deprecated_factor=2) + self.assertEqual(out.item(), 2) + + # Check that we can correctly call the compiled function with the new name, without raising errors + out = compiled_function(new_factor=2) + self.assertEqual(out.item(), 2) + + # Check that we can correctly call the compiled function with both names, without raising errors + out = compiled_function(new_factor=2, deprecated_factor=10) + self.assertEqual(out.item(), 2)