From a1e389e63780ab278e96ff09c90b178dbec3bb5d Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 31 Mar 2025 16:23:37 +0100 Subject: [PATCH] Refactor `return_dict` logic to remove complicated if/else paths (#36794) * SAM * CLIP * SigLIP * GOT-OCR2 (depends on SAM) * SigLIP2 (depends on SigLIP) * trigger tests * Fix SAM * Fix missed indexing, use named attributes * Llama * Aria * Bamba * Update llama: missed outputs return type * (fixup) Aria * DiffLlama * Emu3 * Gemma * Gemma2 * Paligemma * Fix paligemma * Gemma3 * GLM * Helium * JetMoe * Jamba * Mistral * Mistral * Mixtral * Nemotron * Olmo * Olmo2 * Persimmon * Phi * Phi3 * PhiMoe * Qwen2 * Qwen2_moe * StableLM * Starcoder2 * Add return_dict decorator * SAM * Update decorator: compile, export, trace - friendly * Llama (decorator) * SAM (decorator) * Add decorator `can_return_tuple` * Llama * Update to decorator * Update CLIP * Update decorator to store `_is_top_level_module` in self * Update decorator to correctly handle compile/export * Remove is_torchdynamo_compiling constraint, all work fine with self attribute assignment * Typing * GPT NeoX * Fixup * Fix attribute Granite * Fix return type mixtral * Update Gemma3 * Fix Cohere amd Cohere2 * Fixup * Fix corner case for Phi4, when activation is shared * (fix-copies) deepseekv3, phi4 * Fixup * Apply to qwen3/qwen3_moe * Fix --- src/transformers/models/aria/modeling_aria.py | 37 ++--- src/transformers/models/aria/modular_aria.py | 16 +-- .../models/bamba/modeling_bamba.py | 23 +-- .../models/bamba/modular_bamba.py | 14 +- src/transformers/models/clip/modeling_clip.py | 131 ++++++------------ .../models/cohere/modeling_cohere.py | 23 ++- .../models/cohere/modular_cohere.py | 15 +- .../models/cohere2/modeling_cohere2.py | 23 ++- .../models/cohere2/modular_cohere2.py | 5 +- .../deepseek_v3/modeling_deepseek_v3.py | 23 ++- .../models/diffllama/modeling_diffllama.py | 65 +++------ src/transformers/models/emu3/modeling_emu3.py | 33 ++--- src/transformers/models/emu3/modular_emu3.py | 13 +- .../models/gemma/modeling_gemma.py | 51 +++---- .../models/gemma/modular_gemma.py | 5 +- .../models/gemma2/modeling_gemma2.py | 51 +++---- .../models/gemma2/modular_gemma2.py | 16 +-- .../models/gemma3/modeling_gemma3.py | 34 ++--- .../models/gemma3/modular_gemma3.py | 18 +-- src/transformers/models/glm/modeling_glm.py | 51 +++---- .../models/got_ocr2/modeling_got_ocr2.py | 34 ++--- .../models/got_ocr2/modular_got_ocr2.py | 21 +-- .../models/gpt_neox/modeling_gpt_neox.py | 63 +++------ .../models/gpt_neox/modular_gpt_neox.py | 61 +++----- .../models/granite/modeling_granite.py | 23 ++- .../models/granite/modular_granite.py | 16 +-- .../models/helium/modeling_helium.py | 51 +++---- .../models/jamba/modeling_jamba.py | 47 ++----- .../models/jetmoe/modeling_jetmoe.py | 42 ++---- .../models/llama/modeling_llama.py | 65 +++------ .../models/mistral/modeling_mistral.py | 65 +++------ .../models/mistral/modular_mistral.py | 13 +- .../models/mixtral/modeling_mixtral.py | 70 +++------- .../models/mixtral/modular_mixtral.py | 25 +--- .../models/moonshine/modeling_moonshine.py | 47 +++---- .../models/moonshine/modular_moonshine.py | 46 ++---- .../models/nemotron/modeling_nemotron.py | 64 +++------ src/transformers/models/olmo/modeling_olmo.py | 23 ++- .../models/olmo2/modeling_olmo2.py | 23 ++- .../models/paligemma/modeling_paligemma.py | 11 +- .../models/persimmon/modeling_persimmon.py | 51 ++----- src/transformers/models/phi/modeling_phi.py | 51 +++---- src/transformers/models/phi/modular_phi.py | 5 +- src/transformers/models/phi3/modeling_phi3.py | 51 +++---- .../modeling_phi4_multimodal.py | 47 +++---- .../modular_phi4_multimodal.py | 37 ++--- .../models/phimoe/modeling_phimoe.py | 47 ++----- .../models/qwen2/modeling_qwen2.py | 65 +++------ .../models/qwen2_moe/modeling_qwen2_moe.py | 79 +++-------- .../models/qwen3/modeling_qwen3.py | 65 +++------ .../models/qwen3_moe/modeling_qwen3_moe.py | 70 +++------- .../models/qwen3_moe/modular_qwen3_moe.py | 17 +-- src/transformers/models/sam/modeling_sam.py | 46 ++---- .../models/siglip/modeling_siglip.py | 98 +++++-------- .../models/siglip2/modeling_siglip2.py | 99 +++++-------- .../models/siglip2/modular_siglip2.py | 54 ++------ .../models/stablelm/modeling_stablelm.py | 51 ++----- .../models/starcoder2/modeling_starcoder2.py | 51 +++---- .../models/starcoder2/modular_starcoder2.py | 10 +- src/transformers/utils/__init__.py | 1 + src/transformers/utils/generic.py | 64 +++++++++ tests/utils/test_generic.py | 119 ++++++++++++++++ 62 files changed, 943 insertions(+), 1692 deletions(-) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index d87b8ec0c5..5c3b50caef 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -35,6 +35,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -895,6 +896,7 @@ class AriaTextModel(AriaTextPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) def forward( self, @@ -906,16 +908,14 @@ class AriaTextModel(AriaTextPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -998,13 +998,12 @@ class AriaTextModel(AriaTextPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -1182,6 +1181,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1196,11 +1196,10 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1236,10 +1235,9 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1248,12 +1246,11 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1262,10 +1259,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1445,6 +1438,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) @@ -1461,11 +1455,10 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, **loss_kwargs, - ) -> Union[Tuple, AriaCausalLMOutputWithPast]: + ) -> AriaCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1531,7 +1524,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -1562,7 +1554,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - outputs = self.language_model( + outputs: CausalLMOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -1570,12 +1562,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, logits_to_keep=logits_to_keep, cache_position=cache_position, ) - logits = outputs[0] + logits = outputs.logits loss = None if labels is not None: @@ -1583,10 +1574,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return AriaCausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index d341970e29..afc61a02dd 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -33,6 +33,7 @@ from ...image_utils import ( valid_images, validate_preprocess_arguments, ) +from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils import ( @@ -43,6 +44,7 @@ from ...utils import ( TensorType, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -1416,6 +1418,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) @@ -1432,11 +1435,10 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, **loss_kwargs, - ) -> Union[Tuple, AriaCausalLMOutputWithPast]: + ) -> AriaCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1502,7 +1504,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -1533,7 +1534,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - outputs = self.language_model( + outputs: CausalLMOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -1541,12 +1542,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, logits_to_keep=logits_to_keep, cache_position=cache_position, ) - logits = outputs[0] + logits = outputs.logits loss = None if labels is not None: @@ -1554,10 +1554,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return AriaCausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 06986e834c..953e1242c5 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -43,6 +43,7 @@ from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -1191,6 +1192,7 @@ class BambaModel(BambaPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) def forward( self, @@ -1202,18 +1204,15 @@ class BambaModel(BambaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, # NOOP kwargs, for now - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1298,8 +1297,6 @@ class BambaModel(BambaPreTrainedModel): next_cache = None if not use_cache else past_key_values - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1471,6 +1468,7 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1485,11 +1483,10 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1525,10 +1522,9 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1537,12 +1533,11 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1551,10 +1546,6 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index f7fb333142..df4f020057 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -51,6 +51,7 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -935,6 +936,7 @@ class BambaModel(BambaPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) def forward( self, @@ -946,18 +948,15 @@ class BambaModel(BambaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, # NOOP kwargs, for now - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1042,8 +1041,6 @@ class BambaModel(BambaPreTrainedModel): next_cache = None if not use_cache else past_key_values - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1184,6 +1181,7 @@ class BambaModel(BambaPreTrainedModel): class BambaForCausalLM(LlamaForCausalLM): + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1198,11 +1196,10 @@ class BambaForCausalLM(LlamaForCausalLM): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1244,7 +1241,6 @@ class BambaForCausalLM(LlamaForCausalLM): use_cache, output_attentions, output_hidden_states, - return_dict, cache_position, logits_to_keep, **kwargs, diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index f6da323a8a..b06b6fdcf5 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -15,7 +15,7 @@ """PyTorch CLIP model.""" from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple import torch import torch.utils.checkpoint @@ -33,6 +33,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, torch_int, @@ -819,6 +820,7 @@ class CLIPEncoder(nn.Module): self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @can_return_tuple def forward( self, inputs_embeds, @@ -826,8 +828,7 @@ class CLIPEncoder(nn.Module): causal_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: + ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -861,7 +862,6 @@ class CLIPEncoder(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -894,10 +894,10 @@ class CLIPEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, ) @@ -916,6 +916,7 @@ class CLIPTextTransformer(nn.Module): # For attention mask, it differs between `flash_attention_2` and other attention implementations self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + @can_return_tuple @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) def forward( @@ -925,8 +926,7 @@ class CLIPTextTransformer(nn.Module): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -935,7 +935,6 @@ class CLIPTextTransformer(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is None: raise ValueError("You have to specify input_ids") @@ -956,16 +955,15 @@ class CLIPTextTransformer(nn.Module): # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, causal_attention_mask=causal_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.final_layer_norm(last_hidden_state) if self.eos_token_id == 2: @@ -990,9 +988,6 @@ class CLIPTextTransformer(nn.Module): .argmax(dim=-1), ] - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -1022,6 +1017,7 @@ class CLIPTextModel(CLIPPreTrainedModel): def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value + @can_return_tuple @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) def forward( @@ -1031,8 +1027,7 @@ class CLIPTextModel(CLIPPreTrainedModel): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -1050,7 +1045,6 @@ class CLIPTextModel(CLIPPreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled (EOS token) states ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict return self.text_model( input_ids=input_ids, @@ -1058,7 +1052,6 @@ class CLIPTextModel(CLIPPreTrainedModel): position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) @@ -1073,6 +1066,7 @@ class CLIPVisionTransformer(nn.Module): self.encoder = CLIPEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + @can_return_tuple @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) def forward( @@ -1080,9 +1074,8 @@ class CLIPVisionTransformer(nn.Module): pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = False, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -1091,7 +1084,6 @@ class CLIPVisionTransformer(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") @@ -1099,20 +1091,16 @@ class CLIPVisionTransformer(nn.Module): hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -1139,6 +1127,7 @@ class CLIPVisionModel(CLIPPreTrainedModel): def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding + @can_return_tuple @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) def forward( @@ -1147,8 +1136,7 @@ class CLIPVisionModel(CLIPPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -1171,13 +1159,11 @@ class CLIPVisionModel(CLIPPreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, interpolate_pos_encoding=interpolate_pos_encoding, ) @@ -1230,7 +1216,6 @@ class CLIPModel(CLIPPreTrainedModel): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" Returns: @@ -1253,18 +1238,16 @@ class CLIPModel(CLIPPreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - text_outputs = self.text_model( + text_outputs: BaseModelOutputWithPooling = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - pooled_output = text_outputs[1] + pooled_output = text_outputs.pooler_output text_features = self.text_projection(pooled_output) return text_features @@ -1276,7 +1259,6 @@ class CLIPModel(CLIPPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, - return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" Returns: @@ -1305,21 +1287,20 @@ class CLIPModel(CLIPPreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( + vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, ) - pooled_output = vision_outputs[1] # pooled_output + pooled_output = vision_outputs.pooler_output image_features = self.visual_projection(pooled_output) return image_features + @can_return_tuple @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig) def forward( @@ -1332,8 +1313,7 @@ class CLIPModel(CLIPPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CLIPOutput]: + ) -> CLIPOutput: r""" Returns: @@ -1363,29 +1343,26 @@ class CLIPModel(CLIPPreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( + vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, ) - text_outputs = self.text_model( + text_outputs: BaseModelOutputWithPooling = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - image_embeds = vision_outputs[1] + image_embeds = vision_outputs.pooler_output image_embeds = self.visual_projection(image_embeds) - text_embeds = text_outputs[1] + text_embeds = text_outputs.pooler_output text_embeds = self.text_projection(text_embeds) # normalized features @@ -1402,10 +1379,6 @@ class CLIPModel(CLIPPreTrainedModel): if return_loss: loss = clip_loss(logits_per_text) - if not return_dict: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - return CLIPOutput( loss=loss, logits_per_image=logits_per_image, @@ -1445,6 +1418,7 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel): def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value + @can_return_tuple @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig) def forward( @@ -1454,8 +1428,7 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CLIPTextModelOutput]: + ) -> CLIPTextModelOutput: r""" Returns: @@ -1472,25 +1445,17 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel): >>> outputs = model(**inputs) >>> text_embeds = outputs.text_embeds ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - text_outputs = self.text_model( + text_outputs: BaseModelOutputWithPooling = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - - pooled_output = text_outputs[1] - + pooled_output = text_outputs.pooler_output text_embeds = self.text_projection(pooled_output) - if not return_dict: - outputs = (text_embeds, text_outputs[0]) + text_outputs[2:] - return tuple(output for output in outputs if output is not None) - return CLIPTextModelOutput( text_embeds=text_embeds, last_hidden_state=text_outputs.last_hidden_state, @@ -1523,6 +1488,7 @@ class CLIPVisionModelWithProjection(CLIPPreTrainedModel): def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding + @can_return_tuple @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig) def forward( @@ -1531,8 +1497,7 @@ class CLIPVisionModelWithProjection(CLIPPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CLIPVisionModelOutput]: + ) -> CLIPVisionModelOutput: r""" Returns: @@ -1554,24 +1519,16 @@ class CLIPVisionModelWithProjection(CLIPPreTrainedModel): >>> outputs = model(**inputs) >>> image_embeds = outputs.image_embeds ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( + vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, ) - - pooled_output = vision_outputs[1] # pooled_output - + pooled_output = vision_outputs.pooler_output image_embeds = self.visual_projection(pooled_output) - if not return_dict: - outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:] - return tuple(output for output in outputs if output is not None) - return CLIPVisionModelOutput( image_embeds=image_embeds, last_hidden_state=vision_outputs.last_hidden_state, @@ -1605,6 +1562,7 @@ class CLIPForImageClassification(CLIPPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_IMAGE_CLASS_CHECKPOINT, @@ -1618,8 +1576,7 @@ class CLIPForImageClassification(CLIPPreTrainedModel): labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, ImageClassifierOutput]: + ) -> ImageClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., @@ -1630,16 +1587,14 @@ class CLIPForImageClassification(CLIPPreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.vision_model( + outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state # average pool the patch tokens sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1) @@ -1671,10 +1626,6 @@ class CLIPForImageClassification(CLIPPreTrainedModel): loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return ImageClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 60adcf89af..eadc89697c 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -46,6 +46,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -545,6 +546,7 @@ class CohereModel(CoherePreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) def forward( self, @@ -556,16 +558,14 @@ class CohereModel(CoherePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -648,13 +648,12 @@ class CohereModel(CoherePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -822,6 +821,7 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -836,11 +836,10 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -876,10 +875,9 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -888,12 +886,11 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -903,10 +900,6 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index a39489a346..4c1ac5ff33 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -30,7 +30,7 @@ from torch import nn from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -315,11 +315,10 @@ class CohereForCausalLM(LlamaForCausalLM): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -355,10 +354,9 @@ class CohereForCausalLM(LlamaForCausalLM): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -367,12 +365,11 @@ class CohereForCausalLM(LlamaForCausalLM): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -382,10 +379,6 @@ class CohereForCausalLM(LlamaForCausalLM): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index be51a992a8..21489e9b78 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -37,6 +37,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -552,6 +553,7 @@ class Cohere2Model(Cohere2PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) def forward( self, @@ -563,17 +565,15 @@ class Cohere2Model(Cohere2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -669,13 +669,12 @@ class Cohere2Model(Cohere2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() @torch.no_grad() def _update_causal_mask( @@ -808,6 +807,7 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -822,11 +822,10 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -862,10 +861,9 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -874,12 +872,11 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -889,10 +886,6 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index ce092545f1..d76a92accb 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -462,7 +462,6 @@ class Cohere2Model(Gemma2Model): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], @@ -472,7 +471,6 @@ class Cohere2Model(Gemma2Model): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -568,13 +566,12 @@ class Cohere2Model(Gemma2Model): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class Cohere2ForCausalLM(CohereForCausalLM): diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 24870d2f69..67564fbca4 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -25,6 +25,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -691,6 +692,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) def forward( self, @@ -702,16 +704,14 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -794,13 +794,12 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -966,6 +965,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -980,11 +980,10 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1020,10 +1019,9 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1032,12 +1030,11 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1046,10 +1043,6 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 7cadf5673f..933cf15963 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -52,6 +52,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -787,6 +788,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) def forward( self, @@ -798,16 +800,14 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -890,13 +890,12 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -1062,6 +1061,7 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1076,11 +1076,10 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1116,10 +1115,9 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1128,12 +1126,11 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1142,10 +1139,6 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1186,6 +1179,7 @@ class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) def forward( self, @@ -1198,17 +1192,15 @@ class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1217,9 +1209,8 @@ class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1249,10 +1240,6 @@ class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1286,6 +1273,7 @@ class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel): def set_input_embeddings(self, value): self.transformer.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) def forward( self, @@ -1298,9 +1286,8 @@ class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1311,9 +1298,8 @@ class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1321,10 +1307,9 @@ class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1335,10 +1320,6 @@ class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, @@ -1378,6 +1359,7 @@ class DiffLlamaForTokenClassification(DiffLlamaPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1395,17 +1377,15 @@ class DiffLlamaForTokenClassification(DiffLlamaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1414,9 +1394,8 @@ class DiffLlamaForTokenClassification(DiffLlamaPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1424,10 +1403,6 @@ class DiffLlamaForTokenClassification(DiffLlamaPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 43996b4132..607ac4e6b5 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -41,6 +41,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -1370,6 +1371,7 @@ class Emu3TextModel(Emu3PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) def forward( self, @@ -1381,16 +1383,14 @@ class Emu3TextModel(Emu3PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1473,13 +1473,12 @@ class Emu3TextModel(Emu3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -1646,6 +1645,7 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") @@ -1660,11 +1660,10 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1700,10 +1699,9 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1712,12 +1710,11 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1726,10 +1723,6 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1873,6 +1866,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): image = self.vqmodel.decode(image_tokens) return image + @can_return_tuple @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1887,11 +1881,10 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1946,7 +1939,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( @@ -1965,7 +1957,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.text_model( + return self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1974,13 +1966,10 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, ) - return outputs - def prepare_inputs_for_generation( self, input_ids, diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 5fc8e43afa..d411ade67a 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -32,6 +32,7 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -1055,6 +1056,7 @@ class Emu3TextModel(LlamaModel, Emu3PreTrainedModel): [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) + @can_return_tuple @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) def forward(self, **super_kwargs): super().forward(**super_kwargs) @@ -1067,6 +1069,7 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin): super().__init__(config) self.model = Emu3TextModel(config) + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") @@ -1160,6 +1163,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): image = self.vqmodel.decode(image_tokens) return image + @can_return_tuple @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1174,11 +1178,10 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1233,7 +1236,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( @@ -1252,7 +1254,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.text_model( + return self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1261,13 +1263,10 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, ) - return outputs - def prepare_inputs_for_generation( self, input_ids, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9fcd9b27da..dfa7aabfcf 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -43,6 +43,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -510,6 +511,7 @@ class GemmaModel(GemmaPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) def forward( self, @@ -521,16 +523,14 @@ class GemmaModel(GemmaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, # NOOP kwarg for now - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -615,13 +615,12 @@ class GemmaModel(GemmaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -787,6 +786,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -801,11 +801,10 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -841,10 +840,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -853,12 +851,11 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -867,10 +864,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -911,6 +904,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) def forward( self, @@ -923,17 +917,15 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -942,9 +934,8 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -974,10 +965,6 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1017,6 +1004,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1034,17 +1022,15 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1053,9 +1039,8 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1063,10 +1048,6 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 564576be76..50b1d33dcc 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -377,7 +377,6 @@ class GemmaModel(LlamaModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, # NOOP kwarg for now ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -386,7 +385,6 @@ class GemmaModel(LlamaModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -471,13 +469,12 @@ class GemmaModel(LlamaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class GemmaForCausalLM(LlamaForCausalLM): diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 6b23f26208..75e318009c 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -42,6 +42,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -555,6 +556,7 @@ class Gemma2Model(Gemma2PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) def forward( self, @@ -566,17 +568,15 @@ class Gemma2Model(Gemma2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -681,13 +681,12 @@ class Gemma2Model(Gemma2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() @torch.no_grad() def _update_causal_mask( @@ -815,6 +814,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -829,11 +829,10 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -875,9 +874,8 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -886,12 +884,11 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **loss_kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -904,10 +901,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1005,6 +998,7 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) def forward( self, @@ -1017,17 +1011,15 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1036,9 +1028,8 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1068,10 +1059,6 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1111,6 +1098,7 @@ class Gemma2ForTokenClassification(Gemma2PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1128,17 +1116,15 @@ class Gemma2ForTokenClassification(Gemma2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1147,9 +1133,8 @@ class Gemma2ForTokenClassification(Gemma2PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1157,10 +1142,6 @@ class Gemma2ForTokenClassification(Gemma2PreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 06f09fab10..4f8d2e1ba4 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -412,7 +412,6 @@ class Gemma2Model(GemmaModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], @@ -422,7 +421,6 @@ class Gemma2Model(GemmaModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -527,13 +525,12 @@ class Gemma2Model(GemmaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() @torch.no_grad() def _update_causal_mask( @@ -588,7 +585,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, @@ -634,9 +630,8 @@ class Gemma2ForCausalLM(GemmaForCausalLM): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -645,12 +640,11 @@ class Gemma2ForCausalLM(GemmaForCausalLM): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **loss_kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -663,10 +657,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM): if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 7c5cd254ff..e17009f402 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -39,6 +39,7 @@ from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -643,6 +644,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) def forward( self, @@ -654,17 +656,15 @@ class Gemma3TextModel(Gemma3PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -770,13 +770,12 @@ class Gemma3TextModel(Gemma3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() @torch.no_grad() def _update_causal_mask( @@ -906,6 +905,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -920,11 +920,10 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -966,9 +965,8 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -977,12 +975,11 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **loss_kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -995,10 +992,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1222,6 +1215,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): image_features = self.multi_modal_projector(vision_outputs) return image_features + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1239,7 +1233,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: @@ -1304,7 +1297,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict is_training = token_type_ids is not None and labels is not None @@ -1358,7 +1350,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): causal_mask = self._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training ) - outputs = self.language_model( + outputs: CausalLMOutputWithPast = self.language_model( attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -1366,13 +1358,12 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs[0] + logits = outputs.logits loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues @@ -1394,9 +1385,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output return Gemma3CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 4a16c52d50..001bbd8f19 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -28,6 +28,7 @@ from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, + CausalLMOutputWithPast, ModelOutput, ) from ...modeling_rope_utils import rope_config_validation @@ -35,6 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import ( add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -592,7 +594,6 @@ class Gemma3TextModel(Gemma2Model): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], @@ -602,7 +603,6 @@ class Gemma3TextModel(Gemma2Model): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -708,13 +708,12 @@ class Gemma3TextModel(Gemma2Model): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class Gemma3ForCausalLM(Gemma2ForCausalLM): @@ -849,6 +848,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): return causal_mask + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -866,7 +866,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: @@ -931,7 +930,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict is_training = token_type_ids is not None and labels is not None @@ -985,7 +983,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): causal_mask = self._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training ) - outputs = self.language_model( + outputs: CausalLMOutputWithPast = self.language_model( attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -993,13 +991,12 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs[0] + logits = outputs.logits loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues @@ -1021,9 +1018,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output return Gemma3CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 716c97de3f..8bd9031127 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -44,6 +44,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -526,6 +527,7 @@ class GlmModel(GlmPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) def forward( self, @@ -537,16 +539,14 @@ class GlmModel(GlmPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -629,13 +629,12 @@ class GlmModel(GlmPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -801,6 +800,7 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -815,11 +815,10 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -855,10 +854,9 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -867,12 +865,11 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -881,10 +878,6 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -925,6 +918,7 @@ class GlmForSequenceClassification(GlmPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) def forward( self, @@ -937,17 +931,15 @@ class GlmForSequenceClassification(GlmPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -956,9 +948,8 @@ class GlmForSequenceClassification(GlmPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -988,10 +979,6 @@ class GlmForSequenceClassification(GlmPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1031,6 +1018,7 @@ class GlmForTokenClassification(GlmPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1048,17 +1036,15 @@ class GlmForTokenClassification(GlmPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1067,9 +1053,8 @@ class GlmForTokenClassification(GlmPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1077,10 +1062,6 @@ class GlmForTokenClassification(GlmPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index fce0ae86e0..83fa36f0e1 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -28,11 +28,18 @@ import torch import torch.nn as nn import torch.nn.functional as F +from transformers.modeling_outputs import CausalLMOutputWithPast + from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + replace_return_docstrings, +) from ..auto import AutoModelForCausalLM from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig @@ -438,18 +445,17 @@ class GotOcr2VisionEncoder(nn.Module): def get_input_embeddings(self): return self.patch_embed + @can_return_tuple def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, GotOcr2VisionEncoderOutput]: + ) -> GotOcr2VisionEncoderOutput: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") @@ -483,14 +489,6 @@ class GotOcr2VisionEncoder(nn.Module): hidden_states = self.neck(hidden_states) - if not return_dict: - outputs = (hidden_states,) - if output_hidden_states: - outputs = outputs + (all_hidden_states,) - if output_attentions: - outputs = outputs + (all_self_attentions,) - return outputs - return GotOcr2VisionEncoderOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, @@ -738,6 +736,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): image_outputs = self.vision_tower(pixel_values).last_hidden_state return self.multi_modal_projector(image_outputs) + @can_return_tuple @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -752,7 +751,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]: @@ -805,7 +803,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -831,7 +828,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - outputs = self.language_model( + outputs: CausalLMOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -839,12 +836,11 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, ) - logits = outputs[0] + logits = outputs.logits loss = None if labels is not None: @@ -864,10 +860,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return GotOcr2CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index db7ca86631..aed41cc285 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -14,12 +14,13 @@ # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch import torch.nn as nn import torch.utils.checkpoint +from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llava.modeling_llava import ( LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, @@ -30,6 +31,7 @@ from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention from ...configuration_utils import PretrainedConfig from ...utils import ( add_start_docstrings_to_model_forward, + can_return_tuple, is_vision_available, logging, replace_return_docstrings, @@ -226,9 +228,6 @@ class GotOcr2Config(PretrainedConfig): super().__init__(**kwargs) -__all__ = ["GotOcr2VisionConfig", "GotOcr2Config"] - - class GotOcr2MLPBlock(SamMLPBlock): pass @@ -381,6 +380,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): image_outputs = self.vision_tower(pixel_values).last_hidden_state return self.multi_modal_projector(image_outputs) + @can_return_tuple @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -395,10 +395,9 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: + ) -> LlavaCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -448,7 +447,6 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -474,7 +472,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - outputs = self.language_model( + outputs: CausalLMOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -482,12 +480,11 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, ) - logits = outputs[0] + logits = outputs.logits loss = None if labels is not None: @@ -507,10 +504,6 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return GotOcr2CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 4c2d8d5755..11dfbab89d 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -29,6 +29,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -501,6 +502,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): def set_input_embeddings(self, value): self.embed_in = value + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -519,15 +521,13 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache if (input_ids is None) ^ (inputs_embeds is not None): @@ -618,13 +618,12 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -781,6 +780,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.embed_out = new_embeddings + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -795,7 +795,6 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], @@ -827,9 +826,8 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): >>> prediction_logits = outputs.logits ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.gpt_neox( + outputs: BaseModelOutputWithPast = self.gpt_neox( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -839,12 +837,11 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.embed_out(hidden_states[:, slice_indices, :]) @@ -853,10 +850,6 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -891,6 +884,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -909,17 +903,15 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.gpt_neox( + outputs: BaseModelOutputWithPast = self.gpt_neox( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -929,9 +921,8 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state logits = self.score(hidden_states) batch_size = logits.shape[0] @@ -957,10 +948,6 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -982,6 +969,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint="LarsJonasson/pythia-410m-deduped-sft-swedish", @@ -1002,17 +990,15 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.gpt_neox( + outputs: BaseModelOutputWithPast = self.gpt_neox( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -1022,10 +1008,9 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state hidden_states = self.dropout(hidden_states) logits = self.classifier(hidden_states) @@ -1033,10 +1018,6 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -1062,6 +1043,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1081,8 +1063,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1093,9 +1074,8 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.gpt_neox( + outputs: BaseModelOutputWithPast = self.gpt_neox( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1103,10 +1083,9 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1117,10 +1096,6 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 8e64fd120b..1399f1f18f 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -22,6 +22,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -321,6 +322,7 @@ class GPTNeoXModel(LlamaModel, nn.Module): def set_input_embeddings(self, value): self.embed_in = value + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -339,7 +341,6 @@ class GPTNeoXModel(LlamaModel, nn.Module): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -347,7 +348,6 @@ class GPTNeoXModel(LlamaModel, nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache if (input_ids is None) ^ (inputs_embeds is not None): @@ -438,13 +438,12 @@ class GPTNeoXModel(LlamaModel, nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) - return output if return_dict else output.to_tuple() class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @@ -473,6 +472,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.embed_out = new_embeddings + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -487,7 +487,6 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], @@ -519,9 +518,8 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): >>> prediction_logits = outputs.logits ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.gpt_neox( + outputs: BaseModelOutputWithPast = self.gpt_neox( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -531,12 +529,11 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.embed_out(hidden_states[:, slice_indices, :]) @@ -545,10 +542,6 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -583,6 +576,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -601,17 +595,15 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.gpt_neox( + outputs: BaseModelOutputWithPast = self.gpt_neox( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -621,9 +613,8 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state logits = self.score(hidden_states) batch_size = logits.shape[0] @@ -649,10 +640,6 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -674,6 +661,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint="LarsJonasson/pythia-410m-deduped-sft-swedish", @@ -694,17 +682,15 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.gpt_neox( + outputs: BaseModelOutputWithPast = self.gpt_neox( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -714,10 +700,9 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state hidden_states = self.dropout(hidden_states) logits = self.classifier(hidden_states) @@ -725,10 +710,6 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -754,6 +735,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -773,8 +755,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -785,9 +766,8 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.gpt_neox( + outputs: BaseModelOutputWithPast = self.gpt_neox( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -795,10 +775,9 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -809,10 +788,6 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index f25cbe0dac..6b64e18aa7 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -38,6 +38,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -527,6 +528,7 @@ class GraniteModel(GranitePreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(GRANITE_INPUTS_DOCSTRING) def forward( self, @@ -538,16 +540,14 @@ class GraniteModel(GranitePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -628,13 +628,12 @@ class GraniteModel(GranitePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -800,6 +799,7 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GRANITE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -814,11 +814,10 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -854,10 +853,9 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -866,12 +864,11 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -881,10 +878,6 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 3781ea47ad..f6d99e1c30 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -130,7 +130,6 @@ class GraniteModel(LlamaModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -139,7 +138,6 @@ class GraniteModel(LlamaModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -220,13 +218,12 @@ class GraniteModel(LlamaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @@ -244,7 +241,6 @@ class GraniteForCausalLM(LlamaForCausalLM): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], @@ -253,10 +249,9 @@ class GraniteForCausalLM(LlamaForCausalLM): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -265,12 +260,11 @@ class GraniteForCausalLM(LlamaForCausalLM): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -280,10 +274,6 @@ class GraniteForCausalLM(LlamaForCausalLM): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index be55e4ebf9..5e66ce7298 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -45,6 +45,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -513,6 +514,7 @@ class HeliumModel(HeliumPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(HELIUM_INPUTS_DOCSTRING) def forward( self, @@ -524,16 +526,14 @@ class HeliumModel(HeliumPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -616,13 +616,12 @@ class HeliumModel(HeliumPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -788,6 +787,7 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(HELIUM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -802,11 +802,10 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -842,10 +841,9 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -854,12 +852,11 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -868,10 +865,6 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -912,6 +905,7 @@ class HeliumForSequenceClassification(HeliumPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(HELIUM_INPUTS_DOCSTRING) def forward( self, @@ -924,17 +918,15 @@ class HeliumForSequenceClassification(HeliumPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -943,9 +935,8 @@ class HeliumForSequenceClassification(HeliumPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -975,10 +966,6 @@ class HeliumForSequenceClassification(HeliumPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1018,6 +1005,7 @@ class HeliumForTokenClassification(HeliumPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(HELIUM_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1035,17 +1023,15 @@ class HeliumForTokenClassification(HeliumPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1054,9 +1040,8 @@ class HeliumForTokenClassification(HeliumPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1064,10 +1049,6 @@ class HeliumForTokenClassification(HeliumPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 4fba6ba89b..4e4e5f71bf 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -43,6 +43,7 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -1228,6 +1229,7 @@ class JambaModel(JambaPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) def forward( self, @@ -1240,9 +1242,8 @@ class JambaModel(JambaPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: + ) -> MoeModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -1252,8 +1253,6 @@ class JambaModel(JambaPreTrainedModel): ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1339,12 +1338,6 @@ class JambaModel(JambaPreTrainedModel): next_cache = None if not use_cache else past_key_values - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1433,6 +1426,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1448,11 +1442,10 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1493,10 +1486,9 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1507,10 +1499,9 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, cache_position=cache_position, - return_dict=return_dict, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1521,7 +1512,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], + outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, @@ -1529,12 +1520,6 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, @@ -1621,7 +1606,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): """, JAMBA_START_DOCSTRING, ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->Jamba, MIXTRAL->JAMBA +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->Jamba, MIXTRAL->JAMBA, BaseModelOutputWithPast->MoeModelOutputWithPast class JambaForSequenceClassification(JambaPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1638,6 +1623,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) def forward( self, @@ -1650,17 +1636,15 @@ class JambaForSequenceClassification(JambaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: MoeModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1669,9 +1653,8 @@ class JambaForSequenceClassification(JambaPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1701,10 +1684,6 @@ class JambaForSequenceClassification(JambaPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index c7457db3e9..b87ca1376e 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -37,6 +37,7 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -982,6 +983,7 @@ class JetMoeModel(JetMoePreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) def forward( self, @@ -994,9 +996,8 @@ class JetMoeModel(JetMoePreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: + ) -> MoeModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1005,7 +1006,6 @@ class JetMoeModel(JetMoePreTrainedModel): output_router_logits if output_router_logits is not None else self.config.output_router_logits ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1110,8 +1110,6 @@ class JetMoeModel(JetMoePreTrainedModel): if return_legacy_cache: next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1289,6 +1287,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1304,11 +1303,10 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1329,10 +1327,9 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1341,11 +1338,10 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1372,7 +1368,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], + outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, @@ -1380,12 +1376,6 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): if labels is not None: loss += self.aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, @@ -1412,7 +1402,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): """, JETMOE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->JetMoe, LLAMA->JETMOE +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->JetMoe, LLAMA->JETMOE, BaseModelOutputWithPast->MoeModelOutputWithPast class JetMoeForSequenceClassification(JetMoePreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1429,6 +1419,7 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) def forward( self, @@ -1441,17 +1432,15 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: MoeModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1460,9 +1449,8 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1492,10 +1480,6 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 78cf7a930a..f7d8c714d8 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -45,6 +45,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -515,6 +516,7 @@ class LlamaModel(LlamaPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, @@ -526,16 +528,14 @@ class LlamaModel(LlamaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -618,13 +618,12 @@ class LlamaModel(LlamaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -790,6 +789,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -804,11 +804,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -844,10 +843,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -856,12 +854,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -870,10 +867,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -914,6 +907,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, @@ -926,17 +920,15 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -945,9 +937,8 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -977,10 +968,6 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1015,6 +1002,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel): def set_input_embeddings(self, value): self.transformer.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, @@ -1027,9 +1015,8 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1040,9 +1027,8 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1050,10 +1036,9 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1064,10 +1049,6 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, @@ -1107,6 +1088,7 @@ class LlamaForTokenClassification(LlamaPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1124,17 +1106,15 @@ class LlamaForTokenClassification(LlamaPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1143,9 +1123,8 @@ class LlamaForTokenClassification(LlamaPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1153,10 +1132,6 @@ class LlamaForTokenClassification(LlamaPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c7b9a4523d..65044b4dbe 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -30,6 +30,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -480,6 +481,7 @@ class MistralModel(MistralPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) def forward( self, @@ -491,16 +493,14 @@ class MistralModel(MistralPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -583,13 +583,12 @@ class MistralModel(MistralPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -779,6 +778,7 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -793,11 +793,10 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -833,10 +832,9 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -845,12 +843,11 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -859,10 +856,6 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -902,6 +895,7 @@ class MistralForTokenClassification(MistralPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -919,17 +913,15 @@ class MistralForTokenClassification(MistralPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -938,9 +930,8 @@ class MistralForTokenClassification(MistralPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -948,10 +939,6 @@ class MistralForTokenClassification(MistralPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -991,6 +978,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) def forward( self, @@ -1003,17 +991,15 @@ class MistralForSequenceClassification(MistralPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1022,9 +1008,8 @@ class MistralForSequenceClassification(MistralPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1054,10 +1039,6 @@ class MistralForSequenceClassification(MistralPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1091,6 +1072,7 @@ class MistralForQuestionAnswering(MistralPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) def forward( self, @@ -1103,9 +1085,8 @@ class MistralForQuestionAnswering(MistralPreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1116,9 +1097,8 @@ class MistralForQuestionAnswering(MistralPreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1126,10 +1106,9 @@ class MistralForQuestionAnswering(MistralPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1140,10 +1119,6 @@ class MistralForQuestionAnswering(MistralPreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 20f0528627..4548e69a73 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -7,7 +7,7 @@ from torch import nn from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import QuestionAnsweringModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, QuestionAnsweringModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging @@ -302,7 +302,6 @@ class MistralForQuestionAnswering(LlamaForQuestionAnswering): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" @@ -315,9 +314,8 @@ class MistralForQuestionAnswering(LlamaForQuestionAnswering): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -325,10 +323,9 @@ class MistralForQuestionAnswering(LlamaForQuestionAnswering): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -339,10 +336,6 @@ class MistralForQuestionAnswering(LlamaForQuestionAnswering): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 13e14a755d..7cfc266fd3 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -53,6 +53,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -602,6 +603,7 @@ class MixtralModel(MixtralPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) def forward( self, @@ -614,10 +616,9 @@ class MixtralModel(MixtralPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -627,8 +628,6 @@ class MixtralModel(MixtralPreTrainedModel): ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -712,14 +711,13 @@ class MixtralModel(MixtralPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = MoeModelOutputWithPast( + return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -994,6 +992,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1009,11 +1008,10 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1054,10 +1052,9 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1067,12 +1064,11 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1084,7 +1080,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], + outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, @@ -1092,12 +1088,6 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, @@ -1140,6 +1130,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) def forward( self, @@ -1152,17 +1143,15 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1171,9 +1160,8 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1203,10 +1191,6 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1246,6 +1230,7 @@ class MixtralForTokenClassification(MixtralPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1263,17 +1248,15 @@ class MixtralForTokenClassification(MixtralPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1282,9 +1265,8 @@ class MixtralForTokenClassification(MixtralPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1292,10 +1274,6 @@ class MixtralForTokenClassification(MixtralPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -1328,6 +1306,7 @@ class MixtralForQuestionAnswering(MixtralPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) def forward( self, @@ -1340,9 +1319,8 @@ class MixtralForQuestionAnswering(MixtralPreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1353,9 +1331,8 @@ class MixtralForQuestionAnswering(MixtralPreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1363,10 +1340,9 @@ class MixtralForQuestionAnswering(MixtralPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1377,10 +1353,6 @@ class MixtralForQuestionAnswering(MixtralPreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index c7fa30376b..3b470667e9 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -342,10 +342,9 @@ class MixtralModel(MistralModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, MoeModelOutputWithPast]: + ) -> MoeModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -355,8 +354,6 @@ class MixtralModel(MistralModel): ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -440,14 +437,13 @@ class MixtralModel(MistralModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = MoeModelOutputWithPast( + return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) - return output if return_dict else output.to_tuple() class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @@ -475,11 +471,10 @@ class MixtralForCausalLM(MistralForCausalLM): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -520,10 +515,9 @@ class MixtralForCausalLM(MistralForCausalLM): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -533,12 +527,11 @@ class MixtralForCausalLM(MistralForCausalLM): output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -550,7 +543,7 @@ class MixtralForCausalLM(MistralForCausalLM): aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], + outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, @@ -558,12 +551,6 @@ class MixtralForCausalLM(MistralForCausalLM): if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 78438151b8..17e2b9584e 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -47,6 +47,7 @@ from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -615,15 +616,15 @@ class MoonshineEncoder(MoonshinePreTrainedModel): def set_input_embeddings(self, value: nn.Module): self.conv1 = value + @can_return_tuple def forward( self, input_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: r""" Args: input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`): @@ -650,7 +651,6 @@ class MoonshineEncoder(MoonshinePreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_values is None: raise ValueError("You must specify input_values.") @@ -725,12 +725,11 @@ class MoonshineEncoder(MoonshinePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() MOONSHINE_INPUTS_DOCSTRING = r""" @@ -836,6 +835,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(MOONSHINE_INPUTS_DOCSTRING) def forward( self, @@ -847,12 +847,11 @@ class MoonshineDecoder(MoonshinePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: """ Args: encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): @@ -869,7 +868,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -977,14 +975,13 @@ class MoonshineDecoder(MoonshinePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -1399,6 +1396,7 @@ class MoonshineModel(MoonshinePreTrainedModel): return input_features + @can_return_tuple @add_start_docstrings_to_model_forward(MOONSHINE_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1414,7 +1412,6 @@ class MoonshineModel(MoonshinePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" @@ -1442,18 +1439,16 @@ class MoonshineModel(MoonshinePreTrainedModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if encoder_outputs is None: - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( input_values, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + elif not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, @@ -1461,24 +1456,20 @@ class MoonshineModel(MoonshinePreTrainedModel): ) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( + decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_attention_mask=attention_mask, - encoder_hidden_states=encoder_outputs[0], + encoder_hidden_states=encoder_outputs.last_hidden_state, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, ) - if not return_dict: - return decoder_outputs + encoder_outputs - return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, @@ -1537,6 +1528,7 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi def get_input_embeddings(self) -> nn.Module: return self.model.get_input_embeddings() + @can_return_tuple @add_start_docstrings_to_model_forward(MOONSHINE_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1552,10 +1544,9 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + ) -> Seq2SeqLMOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` @@ -1585,7 +1576,6 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi >>> transcription 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: if decoder_input_ids is None and decoder_inputs_embeds is None: @@ -1593,7 +1583,7 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi labels, self.config.pad_token_id, self.config.decoder_start_token_id ) - outputs = self.model( + outputs: Seq2SeqModelOutput = self.model( input_values, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -1605,19 +1595,14 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, ) - logits = self.proj_out(outputs[0]) + logits = self.proj_out(outputs.last_hidden_state) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - return Seq2SeqLMOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index f1fdd7c58d..02938e73d5 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -40,6 +40,7 @@ from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -605,15 +606,15 @@ class MoonshineEncoder(MoonshinePreTrainedModel): def set_input_embeddings(self, value: nn.Module): self.conv1 = value + @can_return_tuple def forward( self, input_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: r""" Args: input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`): @@ -640,7 +641,6 @@ class MoonshineEncoder(MoonshinePreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_values is None: raise ValueError("You must specify input_values.") @@ -715,12 +715,11 @@ class MoonshineEncoder(MoonshinePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class MoonshineDecoder(LlamaModel): @@ -743,7 +742,6 @@ class MoonshineDecoder(LlamaModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, @@ -765,7 +763,6 @@ class MoonshineDecoder(LlamaModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -873,14 +870,13 @@ class MoonshineDecoder(LlamaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) - return output if return_dict else output.to_tuple() MOONSHINE_MODEL_INPUTS_DOCSTRING = r""" @@ -978,6 +974,7 @@ MOONSHINE_MODEL_INPUTS_DOCSTRING = r""" MOONSHINE_START_DOCSTRING, ) class MoonshineModel(WhisperModel): + @can_return_tuple @add_start_docstrings_to_model_forward(MOONSHINE_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -993,9 +990,8 @@ class MoonshineModel(WhisperModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + ) -> Seq2SeqModelOutput: r""" ```python >>> import torch @@ -1017,18 +1013,16 @@ class MoonshineModel(WhisperModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if encoder_outputs is None: - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( input_values, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + elif not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, @@ -1036,24 +1030,20 @@ class MoonshineModel(WhisperModel): ) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( + decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_attention_mask=attention_mask, - encoder_hidden_states=encoder_outputs[0], + encoder_hidden_states=encoder_outputs.last_hidden_state, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, ) - if not return_dict: - return decoder_outputs + encoder_outputs - return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, @@ -1096,6 +1086,7 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi def get_input_embeddings(self) -> nn.Module: return self.model.get_input_embeddings() + @can_return_tuple @add_start_docstrings_to_model_forward(MOONSHINE_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1111,10 +1102,9 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + ) -> Seq2SeqLMOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` @@ -1144,7 +1134,6 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi >>> transcription 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: if decoder_input_ids is None and decoder_inputs_embeds is None: @@ -1152,7 +1141,7 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi labels, self.config.pad_token_id, self.config.decoder_start_token_id ) - outputs = self.model( + outputs: Seq2SeqModelOutput = self.model( input_values, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -1164,19 +1153,14 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, ) - logits = self.proj_out(outputs[0]) + logits = self.proj_out(outputs.last_hidden_state) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - return Seq2SeqLMOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index b1cae89b0b..6a362804a2 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -42,6 +42,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -760,6 +761,7 @@ class NemotronModel(NemotronPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING) def forward( self, @@ -771,15 +773,13 @@ class NemotronModel(NemotronPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -866,8 +866,6 @@ class NemotronModel(NemotronPreTrainedModel): next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1037,6 +1035,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1052,11 +1051,10 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1092,10 +1090,9 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1104,11 +1101,10 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1117,10 +1113,6 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1162,6 +1154,7 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING) def forward( self, @@ -1174,17 +1167,15 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1193,9 +1184,8 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1225,10 +1215,6 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1264,6 +1250,7 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel): def set_input_embeddings(self, value): self.transformer.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING) def forward( self, @@ -1276,9 +1263,8 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1289,9 +1275,8 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1299,10 +1284,9 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1313,10 +1297,6 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, @@ -1357,6 +1337,7 @@ class NemotronForTokenClassification(NemotronPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1374,17 +1355,15 @@ class NemotronForTokenClassification(NemotronPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1393,9 +1372,8 @@ class NemotronForTokenClassification(NemotronPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1403,10 +1381,6 @@ class NemotronForTokenClassification(NemotronPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 23acd45eb2..79b8c28ca2 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -24,6 +24,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -491,6 +492,7 @@ class OlmoModel(OlmoPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) def forward( self, @@ -502,16 +504,14 @@ class OlmoModel(OlmoPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -594,13 +594,12 @@ class OlmoModel(OlmoPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -766,6 +765,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -780,11 +780,10 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -820,10 +819,9 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -832,12 +830,11 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -846,10 +843,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 9af94ae0aa..cc83dd4d17 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -23,6 +23,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -492,6 +493,7 @@ class Olmo2Model(Olmo2PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) def forward( self, @@ -503,16 +505,14 @@ class Olmo2Model(Olmo2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -595,13 +595,12 @@ class Olmo2Model(Olmo2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -767,6 +766,7 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -781,11 +781,10 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -821,10 +820,9 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -833,12 +831,11 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -847,10 +844,6 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 1065e801a4..f14e2d4535 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -23,6 +23,7 @@ from torch import nn from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin +from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -537,7 +538,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi causal_mask = self._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training ) - outputs = self.language_model( + outputs: CausalLMOutputWithPast = self.language_model( attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -545,7 +546,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, logits_to_keep=logits_to_keep, **lm_kwargs, @@ -573,11 +574,8 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return PaliGemmaCausalLMOutputWithPast( + output = PaliGemmaCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, @@ -585,6 +583,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) + return output if return_dict else output.to_tuple() def prepare_inputs_for_generation( self, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 995770b35c..1e3a784026 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -42,6 +42,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -550,6 +551,7 @@ class PersimmonModel(PersimmonPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) def forward( self, @@ -561,17 +563,14 @@ class PersimmonModel(PersimmonPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -667,8 +666,6 @@ class PersimmonModel(PersimmonPreTrainedModel): if return_legacy_cache: next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -844,6 +841,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -858,11 +856,10 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -899,10 +896,9 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -911,11 +907,10 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # No upscaling to float was ever done for Persimmon slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -929,10 +924,6 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): **kwargs, ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -974,6 +965,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) def forward( self, @@ -986,17 +978,15 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1005,9 +995,8 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1037,10 +1026,6 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1081,6 +1066,7 @@ class PersimmonForTokenClassification(PersimmonPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1098,17 +1084,15 @@ class PersimmonForTokenClassification(PersimmonPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1117,9 +1101,8 @@ class PersimmonForTokenClassification(PersimmonPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1127,10 +1110,6 @@ class PersimmonForTokenClassification(PersimmonPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index a5a008a6f1..612bd70407 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -29,6 +29,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -488,6 +489,7 @@ class PhiModel(PhiPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) def forward( self, @@ -499,16 +501,14 @@ class PhiModel(PhiPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -588,13 +588,12 @@ class PhiModel(PhiPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -760,6 +759,7 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -774,11 +774,10 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -814,10 +813,9 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -826,12 +824,11 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -840,10 +837,6 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -884,6 +877,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) def forward( self, @@ -896,17 +890,15 @@ class PhiForSequenceClassification(PhiPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -915,9 +907,8 @@ class PhiForSequenceClassification(PhiPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -947,10 +938,6 @@ class PhiForSequenceClassification(PhiPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -990,6 +977,7 @@ class PhiForTokenClassification(PhiPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1007,17 +995,15 @@ class PhiForTokenClassification(PhiPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1026,9 +1012,8 @@ class PhiForTokenClassification(PhiPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1036,10 +1021,6 @@ class PhiForTokenClassification(PhiPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 4dcf74d741..e01d433aa1 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -189,7 +189,6 @@ class PhiModel(LlamaModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -198,7 +197,6 @@ class PhiModel(LlamaModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -278,13 +276,12 @@ class PhiModel(LlamaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class PhiForCausalLM(LlamaForCausalLM): diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index bd781216da..0d8238683e 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -45,6 +45,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -555,6 +556,7 @@ class Phi3Model(Phi3PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) def forward( self, @@ -566,16 +568,14 @@ class Phi3Model(Phi3PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -658,13 +658,12 @@ class Phi3Model(Phi3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -854,6 +853,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -868,11 +868,10 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -908,10 +907,9 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -920,12 +918,11 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -934,10 +931,6 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1017,6 +1010,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) def forward( self, @@ -1029,17 +1023,15 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1048,9 +1040,8 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1080,10 +1071,6 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1123,6 +1110,7 @@ class Phi3ForTokenClassification(Phi3PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1140,17 +1128,15 @@ class Phi3ForTokenClassification(Phi3PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1159,9 +1145,8 @@ class Phi3ForTokenClassification(Phi3PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1169,10 +1154,6 @@ class Phi3ForTokenClassification(Phi3PreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 5d44fae131..5754459529 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -47,6 +47,7 @@ from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, torch_int, @@ -212,14 +213,14 @@ class Phi4MultimodalVisionEncoder(nn.Module): self.gradient_checkpointing = False # Ignore copy + @can_return_tuple def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: + ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -246,7 +247,6 @@ class Phi4MultimodalVisionEncoder(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -277,10 +277,10 @@ class Phi4MultimodalVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, ) @@ -567,13 +567,11 @@ class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): patch_attention_mask: Optional[torch.BoolTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size = pixel_values.size(0) if patch_attention_mask is None: @@ -602,15 +600,14 @@ class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): else patch_attention_mask ) - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = self.head( @@ -618,9 +615,6 @@ class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): attention_mask=patch_attention_mask, ) - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -1845,6 +1839,7 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) def forward( self, @@ -1862,18 +1857,15 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1961,13 +1953,12 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -2154,6 +2145,7 @@ class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=Phi4MultimodalConfig) def forward( @@ -2173,11 +2165,10 @@ class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -2209,10 +2200,9 @@ class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -2227,12 +2217,11 @@ class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -2241,10 +2230,6 @@ class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 06424941ec..c9e9f209ad 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import DynamicCache from ...configuration_utils import PretrainedConfig from ...modeling_outputs import ( + BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast, @@ -34,6 +35,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -668,13 +670,11 @@ class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): patch_attention_mask: Optional[torch.BoolTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size = pixel_values.size(0) if patch_attention_mask is None: @@ -703,15 +703,14 @@ class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): else patch_attention_mask ) - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = self.head( @@ -719,9 +718,6 @@ class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): attention_mask=patch_attention_mask, ) - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -1549,6 +1545,7 @@ class Phi4MultimodalModel(Phi3Model, nn.Module): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) def forward( self, @@ -1566,18 +1563,15 @@ class Phi4MultimodalModel(Phi3Model, nn.Module): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1665,13 +1659,12 @@ class Phi4MultimodalModel(Phi3Model, nn.Module): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class Phi4MultimodalForCausalLM(Phi3ForCausalLM, nn.Module): @@ -1686,6 +1679,7 @@ class Phi4MultimodalForCausalLM(Phi3ForCausalLM, nn.Module): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=Phi4MultimodalConfig) def forward( @@ -1705,11 +1699,10 @@ class Phi4MultimodalForCausalLM(Phi3ForCausalLM, nn.Module): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1741,10 +1734,9 @@ class Phi4MultimodalForCausalLM(Phi3ForCausalLM, nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1759,12 +1751,11 @@ class Phi4MultimodalForCausalLM(Phi3ForCausalLM, nn.Module): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1773,10 +1764,6 @@ class Phi4MultimodalForCausalLM(Phi3ForCausalLM, nn.Module): if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 7701a76e6c..89e82b7d6b 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -37,6 +37,7 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -1031,6 +1032,7 @@ class PhimoeModel(PhimoePreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PHIMOE_INPUTS_DOCSTRING) def forward( self, @@ -1043,9 +1045,8 @@ class PhimoeModel(PhimoePreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: + ) -> MoeModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -1055,8 +1056,6 @@ class PhimoeModel(PhimoePreTrainedModel): ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" @@ -1159,12 +1158,6 @@ class PhimoeModel(PhimoePreTrainedModel): if return_legacy_cache: next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1366,6 +1359,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHIMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1382,11 +1376,10 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1429,10 +1422,9 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1442,11 +1434,10 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, - return_dict=return_dict, cache_position=cache_position, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1458,7 +1449,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], + outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, @@ -1466,12 +1457,6 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, @@ -1537,7 +1522,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): PHIMOE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phimoe, LLAMA->PHIMOE +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phimoe, LLAMA->PHIMOE, BaseModelOutputWithPast->MoeModelOutputWithPast class PhimoeForSequenceClassification(PhimoePreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1554,6 +1539,7 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(PHIMOE_INPUTS_DOCSTRING) def forward( self, @@ -1566,17 +1552,15 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: MoeModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1585,9 +1569,8 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1617,10 +1600,6 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index e009b6f693..232598b231 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -30,6 +30,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -493,6 +494,7 @@ class Qwen2Model(Qwen2PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) def forward( self, @@ -504,16 +506,14 @@ class Qwen2Model(Qwen2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -596,13 +596,12 @@ class Qwen2Model(Qwen2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -792,6 +791,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -806,11 +806,10 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -846,10 +845,9 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -858,12 +856,11 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -872,10 +869,6 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -916,6 +909,7 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) def forward( self, @@ -928,17 +922,15 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -947,9 +939,8 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -979,10 +970,6 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1022,6 +1009,7 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1039,17 +1027,15 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1058,9 +1044,8 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1068,10 +1053,6 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -1104,6 +1085,7 @@ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): def set_input_embeddings(self, value): self.transformer.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) def forward( self, @@ -1116,9 +1098,8 @@ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1129,9 +1110,8 @@ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1139,10 +1119,9 @@ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1153,10 +1132,6 @@ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 40263884b1..8e1e8de79e 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -45,6 +45,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -919,6 +920,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( self, @@ -931,9 +933,8 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: + ) -> MoeModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -943,8 +944,6 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1046,12 +1045,6 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): if return_legacy_cache: next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1250,6 +1243,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1265,11 +1259,10 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1309,10 +1302,9 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1322,11 +1314,10 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, - return_dict=return_dict, cache_position=cache_position, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1338,7 +1329,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], + outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, @@ -1346,12 +1337,6 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, @@ -1378,7 +1363,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): """, QWEN2MOE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE, BaseModelOutputWithPast->MoeModelOutputWithPast class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1395,6 +1380,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( self, @@ -1407,17 +1393,15 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: MoeModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1426,9 +1410,8 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1458,10 +1441,6 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1478,7 +1457,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): """, QWEN2MOE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE, BaseModelOutputWithPast->MoeModelOutputWithPast class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1502,6 +1481,7 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1519,17 +1499,15 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1538,9 +1516,8 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1548,10 +1525,6 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -1567,7 +1540,7 @@ SQuAD (a linear layer on top of the hidden-states output to compute `span start """, QWEN2MOE_START_DOCSTRING, ) -# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE +# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE, BaseModelOutputWithPast->MoeModelOutputWithPast class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): base_model_prefix = "model" @@ -1585,6 +1558,7 @@ class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( self, @@ -1597,9 +1571,8 @@ class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1610,9 +1583,8 @@ class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1620,10 +1592,9 @@ class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1634,10 +1605,6 @@ class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 069691cee2..b235c8acde 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -45,6 +45,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -520,6 +521,7 @@ class Qwen3Model(Qwen3PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) def forward( self, @@ -531,16 +533,14 @@ class Qwen3Model(Qwen3PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -623,13 +623,12 @@ class Qwen3Model(Qwen3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -819,6 +818,7 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -833,11 +833,10 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -873,10 +872,9 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -885,12 +883,11 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -899,10 +896,6 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -943,6 +936,7 @@ class Qwen3ForSequenceClassification(Qwen3PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) def forward( self, @@ -955,17 +949,15 @@ class Qwen3ForSequenceClassification(Qwen3PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -974,9 +966,8 @@ class Qwen3ForSequenceClassification(Qwen3PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1006,10 +997,6 @@ class Qwen3ForSequenceClassification(Qwen3PreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1049,6 +1036,7 @@ class Qwen3ForTokenClassification(Qwen3PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1066,17 +1054,15 @@ class Qwen3ForTokenClassification(Qwen3PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1085,9 +1071,8 @@ class Qwen3ForTokenClassification(Qwen3PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1095,10 +1080,6 @@ class Qwen3ForTokenClassification(Qwen3PreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -1131,6 +1112,7 @@ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel): def set_input_embeddings(self, value): self.transformer.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING) def forward( self, @@ -1143,9 +1125,8 @@ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1156,9 +1137,8 @@ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1166,10 +1146,9 @@ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1180,10 +1159,6 @@ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index a697289993..3897cd44bd 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -48,6 +48,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -615,6 +616,7 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) def forward( self, @@ -627,10 +629,9 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -640,8 +641,6 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -725,14 +724,13 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = MoeModelOutputWithPast( + return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -1007,6 +1005,7 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1022,11 +1021,10 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1067,10 +1065,9 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1080,12 +1077,11 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1097,7 +1093,7 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], + outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, @@ -1105,12 +1101,6 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, @@ -1153,6 +1143,7 @@ class Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) def forward( self, @@ -1165,17 +1156,15 @@ class Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1184,9 +1173,8 @@ class Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1216,10 +1204,6 @@ class Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1259,6 +1243,7 @@ class Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1276,17 +1261,15 @@ class Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1295,9 +1278,8 @@ class Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1305,10 +1287,6 @@ class Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -1341,6 +1319,7 @@ class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel): def set_input_embeddings(self, value): self.transformer.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) def forward( self, @@ -1353,9 +1332,8 @@ class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel): end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -1366,9 +1344,8 @@ class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1376,10 +1353,9 @@ class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -1390,10 +1366,6 @@ class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel): if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py index 31969dc885..3e80e28a0c 100644 --- a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py @@ -23,7 +23,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import MoeCausalLMOutputWithPast +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...processing_utils import Unpack from ...utils import ( LossKwargs, @@ -254,7 +254,6 @@ class Qwen3MoeForCausalLM(MixtralForCausalLM): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], @@ -299,10 +298,9 @@ class Qwen3MoeForCausalLM(MixtralForCausalLM): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -312,12 +310,11 @@ class Qwen3MoeForCausalLM(MixtralForCausalLM): output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -329,7 +326,7 @@ class Qwen3MoeForCausalLM(MixtralForCausalLM): aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], + outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, @@ -337,12 +334,6 @@ class Qwen3MoeForCausalLM(MixtralForCausalLM): if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index dd36b23ba6..b25d1b7318 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -16,7 +16,7 @@ import collections from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -31,6 +31,7 @@ from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -406,13 +407,11 @@ class SamTwoWayTransformer(nn.Module): target_embedding=None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict all_attentions = () @@ -1121,18 +1120,17 @@ class SamVisionEncoder(nn.Module): def get_input_embeddings(self): return self.patch_embed + @can_return_tuple def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SamVisionEncoderOutput]: + ) -> SamVisionEncoderOutput: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") @@ -1166,14 +1164,6 @@ class SamVisionEncoder(nn.Module): hidden_states = self.neck(hidden_states) - if not return_dict: - outputs = (hidden_states,) - if output_hidden_states: - outputs = outputs + (all_hidden_states,) - if output_attentions: - outputs = outputs + (all_self_attentions,) - return outputs - return SamVisionEncoderOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, @@ -1396,7 +1386,6 @@ class SamModel(SamPreTrainedModel): pixel_values, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ): r""" Returns the image embeddings by passing the pixel values through the vision encoder. @@ -1408,15 +1397,11 @@ class SamModel(SamPreTrainedModel): Whether or not to return the attentions tensors of all attention layers. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ vision_output = self.vision_encoder( pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) image_embeddings = vision_output[0] return image_embeddings @@ -1454,6 +1439,7 @@ class SamModel(SamPreTrainedModel): ) return prompt_output + @can_return_tuple @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING) def forward( self, @@ -1468,9 +1454,8 @@ class SamModel(SamPreTrainedModel): target_embedding: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs, - ) -> List[Dict[str, torch.Tensor]]: + ) -> SamImageSegmentationOutput: r""" Example: @@ -1500,7 +1485,6 @@ class SamModel(SamPreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None and image_embeddings is None: raise ValueError("Either pixel_values or image_embeddings must be provided.") @@ -1537,18 +1521,17 @@ class SamModel(SamPreTrainedModel): vision_hidden_states = None if pixel_values is not None: - vision_outputs = self.vision_encoder( + vision_outputs: SamVisionEncoderOutput = self.vision_encoder( pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - image_embeddings = vision_outputs[0] + image_embeddings = vision_outputs.last_hidden_state if output_hidden_states: - vision_hidden_states = vision_outputs[1] + vision_hidden_states = vision_outputs.hidden_states if output_attentions: - vision_attentions = vision_outputs[-1] + vision_attentions = vision_outputs.attentions if input_points is not None and input_labels is None: input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) @@ -1580,15 +1563,6 @@ class SamModel(SamPreTrainedModel): output_attentions=output_attentions, ) - if not return_dict: - output = (iou_predictions, low_res_masks) - if output_hidden_states: - output = output + (vision_hidden_states,) - - if output_attentions: - output = output + (vision_attentions, mask_decoder_attentions) - return output - return SamImageSegmentationOutput( iou_scores=iou_predictions, pred_masks=low_res_masks, diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 34b0ee370c..0288b78381 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -17,7 +17,7 @@ import math import warnings from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple import numpy as np import torch @@ -35,6 +35,7 @@ from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, torch_int, @@ -848,14 +849,14 @@ class SiglipEncoder(nn.Module): self.gradient_checkpointing = False # Ignore copy + @can_return_tuple def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: + ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -882,7 +883,6 @@ class SiglipEncoder(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -913,10 +913,10 @@ class SiglipEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, ) @@ -932,6 +932,7 @@ class SiglipTextTransformer(nn.Module): self.head = nn.Linear(embed_dim, config.projection_size) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) def forward( @@ -941,8 +942,7 @@ class SiglipTextTransformer(nn.Module): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -951,7 +951,6 @@ class SiglipTextTransformer(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is None: raise ValueError("You have to specify input_ids") @@ -967,24 +966,20 @@ class SiglipTextTransformer(nn.Module): # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.final_layer_norm(last_hidden_state) # Assuming "sticky" EOS tokenization, last token is always EOS. pooled_output = last_hidden_state[:, -1, :] pooled_output = self.head(pooled_output) - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -1012,6 +1007,7 @@ class SiglipTextModel(SiglipPreTrainedModel): def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) def forward( @@ -1021,8 +1017,7 @@ class SiglipTextModel(SiglipPreTrainedModel): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -1041,7 +1036,6 @@ class SiglipTextModel(SiglipPreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled (EOS token) states ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict return self.text_model( input_ids=input_ids, @@ -1049,7 +1043,6 @@ class SiglipTextModel(SiglipPreTrainedModel): position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) @@ -1066,6 +1059,7 @@ class SiglipVisionTransformer(nn.Module): if self.use_head: self.head = SiglipMultiheadAttentionPoolingHead(config) + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) def forward( @@ -1073,9 +1067,8 @@ class SiglipVisionTransformer(nn.Module): pixel_values, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = False, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -1084,23 +1077,19 @@ class SiglipVisionTransformer(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) pooler_output = self.head(last_hidden_state) if self.use_head else None - if not return_dict: - return (last_hidden_state, pooler_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, @@ -1153,6 +1142,7 @@ class SiglipVisionModel(SiglipPreTrainedModel): def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) def forward( @@ -1160,9 +1150,8 @@ class SiglipVisionModel(SiglipPreTrainedModel): pixel_values, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -1185,13 +1174,11 @@ class SiglipVisionModel(SiglipPreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, interpolate_pos_encoding=interpolate_pos_encoding, ) @@ -1240,7 +1227,6 @@ class SiglipModel(SiglipPreTrainedModel): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" Returns: @@ -1266,18 +1252,16 @@ class SiglipModel(SiglipPreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - text_outputs = self.text_model( + text_outputs: BaseModelOutputWithPooling = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - pooled_output = text_outputs[1] + pooled_output = text_outputs.pooler_output return pooled_output @@ -1287,7 +1271,6 @@ class SiglipModel(SiglipPreTrainedModel): pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: r""" @@ -1319,20 +1302,19 @@ class SiglipModel(SiglipPreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( + vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, interpolate_pos_encoding=interpolate_pos_encoding, ) - pooled_output = vision_outputs[1] + pooled_output = vision_outputs.pooler_output return pooled_output + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) def forward( @@ -1344,9 +1326,8 @@ class SiglipModel(SiglipPreTrainedModel): return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, - ) -> Union[Tuple, SiglipOutput]: + ) -> SiglipOutput: r""" Returns: @@ -1381,27 +1362,24 @@ class SiglipModel(SiglipPreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( + vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, interpolate_pos_encoding=interpolate_pos_encoding, ) - text_outputs = self.text_model( + text_outputs: BaseModelOutputWithPooling = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - image_embeds = vision_outputs[1] - text_embeds = text_outputs[1] + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output # normalized features image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) @@ -1424,10 +1402,6 @@ class SiglipModel(SiglipPreTrainedModel): nll = -torch.sum(loglik, dim=-1) loss = nll.mean() - if not return_dict: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - return SiglipOutput( loss=loss, logits_per_image=logits_per_image, @@ -1467,6 +1441,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1475,9 +1450,8 @@ class SiglipForImageClassification(SiglipPreTrainedModel): labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, - ) -> Union[tuple, ImageClassifierOutput]: + ) -> ImageClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., @@ -1515,17 +1489,15 @@ class SiglipForImageClassification(SiglipPreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.vision_model( + outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, interpolate_pos_encoding=interpolate_pos_encoding, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state # average pool the patch tokens sequence_output = torch.mean(sequence_output, dim=1) @@ -1557,10 +1529,6 @@ class SiglipForImageClassification(SiglipPreTrainedModel): loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return ImageClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index 4fbe921483..663d9f3dd0 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -21,7 +21,7 @@ import math import warnings from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple import numpy as np import torch @@ -39,6 +39,7 @@ from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -566,14 +567,14 @@ class Siglip2Encoder(nn.Module): self.gradient_checkpointing = False # Ignore copy + @can_return_tuple def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: + ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -600,7 +601,6 @@ class Siglip2Encoder(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -631,10 +631,10 @@ class Siglip2Encoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, ) @@ -670,6 +670,7 @@ class Siglip2VisionTransformer(nn.Module): self.head = Siglip2MultiheadAttentionPoolingHead(config) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig) def forward( @@ -679,8 +680,7 @@ class Siglip2VisionTransformer(nn.Module): spatial_shapes: torch.LongTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -689,7 +689,6 @@ class Siglip2VisionTransformer(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict hidden_states = self.embeddings(pixel_values, spatial_shapes) @@ -699,20 +698,17 @@ class Siglip2VisionTransformer(nn.Module): else: encoder_attention_mask = attention_mask - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None - if not return_dict: - return (last_hidden_state, pooler_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, @@ -902,6 +898,7 @@ class Siglip2TextTransformer(nn.Module): self.head = nn.Linear(embed_dim, config.projection_size) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig) def forward( @@ -911,8 +908,7 @@ class Siglip2TextTransformer(nn.Module): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -921,7 +917,6 @@ class Siglip2TextTransformer(nn.Module): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is None: raise ValueError("You have to specify input_ids") @@ -937,24 +932,20 @@ class Siglip2TextTransformer(nn.Module): # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.final_layer_norm(last_hidden_state) # Assuming "sticky" EOS tokenization, last token is always EOS. pooled_output = last_hidden_state[:, -1, :] pooled_output = self.head(pooled_output) - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -1104,6 +1095,7 @@ class Siglip2TextModel(Siglip2PreTrainedModel): def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig) def forward( @@ -1113,8 +1105,7 @@ class Siglip2TextModel(Siglip2PreTrainedModel): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -1133,7 +1124,6 @@ class Siglip2TextModel(Siglip2PreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled (EOS token) states ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict return self.text_model( input_ids=input_ids, @@ -1141,7 +1131,6 @@ class Siglip2TextModel(Siglip2PreTrainedModel): position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) @@ -1195,6 +1184,7 @@ class Siglip2VisionModel(Siglip2PreTrainedModel): def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig) def forward( @@ -1204,8 +1194,7 @@ class Siglip2VisionModel(Siglip2PreTrainedModel): spatial_shapes: torch.LongTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: @@ -1228,15 +1217,12 @@ class Siglip2VisionModel(Siglip2PreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - return self.vision_model( pixel_values=pixel_values, attention_mask=pixel_attention_mask, spatial_shapes=spatial_shapes, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) @@ -1284,7 +1270,6 @@ class Siglip2Model(Siglip2PreTrainedModel): position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" Returns: @@ -1310,18 +1295,16 @@ class Siglip2Model(Siglip2PreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - text_outputs = self.text_model( + text_outputs: BaseModelOutputWithPooling = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - pooled_output = text_outputs[1] + pooled_output = text_outputs.pooler_output return pooled_output @@ -1333,7 +1316,6 @@ class Siglip2Model(Siglip2PreTrainedModel): spatial_shapes: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" Returns: @@ -1364,21 +1346,20 @@ class Siglip2Model(Siglip2PreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( + vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, attention_mask=pixel_attention_mask, spatial_shapes=spatial_shapes, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - pooled_output = vision_outputs[1] + pooled_output = vision_outputs.pooler_output return pooled_output + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Siglip2Output, config_class=Siglip2Config) def forward( @@ -1392,8 +1373,7 @@ class Siglip2Model(Siglip2PreTrainedModel): return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Siglip2Output]: + ) -> Siglip2Output: r""" Returns: @@ -1428,28 +1408,25 @@ class Siglip2Model(Siglip2PreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( + vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, attention_mask=pixel_attention_mask, spatial_shapes=spatial_shapes, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - text_outputs = self.text_model( + text_outputs: BaseModelOutputWithPooling = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - image_embeds = vision_outputs[1] - text_embeds = text_outputs[1] + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output # normalized features image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) @@ -1472,10 +1449,6 @@ class Siglip2Model(Siglip2PreTrainedModel): nll = -torch.sum(loglik, dim=-1) loss = nll.mean() - if not return_dict: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - return Siglip2Output( loss=loss, logits_per_image=logits_per_image, @@ -1515,6 +1488,7 @@ class Siglip2ForImageClassification(Siglip2PreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1525,8 +1499,7 @@ class Siglip2ForImageClassification(Siglip2PreTrainedModel): labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, ImageClassifierOutput]: + ) -> ImageClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., @@ -1564,18 +1537,16 @@ class Siglip2ForImageClassification(Siglip2PreTrainedModel): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.vision_model( + outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values, attention_mask=pixel_attention_mask, spatial_shapes=spatial_shapes, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state # average pool the patch tokens if pixel_attention_mask is not None: @@ -1612,10 +1583,6 @@ class Siglip2ForImageClassification(Siglip2PreTrainedModel): loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return ImageClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/siglip2/modular_siglip2.py b/src/transformers/models/siglip2/modular_siglip2.py index 6fac003051..92e106bc59 100644 --- a/src/transformers/models/siglip2/modular_siglip2.py +++ b/src/transformers/models/siglip2/modular_siglip2.py @@ -21,6 +21,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig from transformers.models.siglip.modeling_siglip import ( + BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, SiglipForImageClassification, @@ -242,7 +243,6 @@ class Siglip2VisionTransformer(SiglipVisionTransformer): spatial_shapes: torch.LongTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -252,7 +252,6 @@ class Siglip2VisionTransformer(SiglipVisionTransformer): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict hidden_states = self.embeddings(pixel_values, spatial_shapes) @@ -262,20 +261,17 @@ class Siglip2VisionTransformer(SiglipVisionTransformer): else: encoder_attention_mask = attention_mask - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None - if not return_dict: - return (last_hidden_state, pooler_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, @@ -326,17 +322,13 @@ class Siglip2VisionModel(SiglipVisionModel): spatial_shapes: torch.LongTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - + ) -> BaseModelOutputWithPooling: return self.vision_model( pixel_values=pixel_values, attention_mask=pixel_attention_mask, spatial_shapes=spatial_shapes, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) @@ -349,25 +341,22 @@ class Siglip2Model(SiglipModel): spatial_shapes: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> torch.FloatTensor: # Use Siglip2Model's config for some fields (if specified) instead of those of vision & text components. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( + vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, attention_mask=pixel_attention_mask, spatial_shapes=spatial_shapes, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - pooled_output = vision_outputs[1] + pooled_output = vision_outputs.pooler_output return pooled_output @@ -383,35 +372,31 @@ class Siglip2Model(SiglipModel): return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Siglip2Output]: + ) -> Siglip2Output: # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( + vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, attention_mask=pixel_attention_mask, spatial_shapes=spatial_shapes, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - text_outputs = self.text_model( + text_outputs: BaseModelOutputWithPooling = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - image_embeds = vision_outputs[1] - text_embeds = text_outputs[1] + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output # normalized features image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) @@ -434,10 +419,6 @@ class Siglip2Model(SiglipModel): nll = -torch.sum(loglik, dim=-1) loss = nll.mean() - if not return_dict: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - return Siglip2Output( loss=loss, logits_per_image=logits_per_image, @@ -459,24 +440,21 @@ class Siglip2ForImageClassification(SiglipForImageClassification): labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, ImageClassifierOutput]: + ) -> ImageClassifierOutput: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.vision_model( + outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values, attention_mask=pixel_attention_mask, spatial_shapes=spatial_shapes, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state # average pool the patch tokens if pixel_attention_mask is not None: @@ -513,10 +491,6 @@ class Siglip2ForImageClassification(SiglipForImageClassification): loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return ImageClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index a11e0627b8..ab14bdb8e6 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -43,6 +43,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -804,6 +805,7 @@ class StableLmModel(StableLmPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) def forward( self, @@ -815,17 +817,14 @@ class StableLmModel(StableLmPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -921,8 +920,6 @@ class StableLmModel(StableLmPreTrainedModel): if return_legacy_cache: next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1099,6 +1096,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -1114,11 +1112,10 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1155,9 +1152,8 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1166,11 +1162,10 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # No upscaling to float was ever done for StableLm slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1184,10 +1179,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): **kwargs, ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1229,6 +1220,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) def forward( self, @@ -1241,17 +1233,15 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1260,9 +1250,8 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -1292,10 +1281,6 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1336,6 +1321,7 @@ class StableLmForTokenClassification(StableLmPreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1353,17 +1339,15 @@ class StableLmForTokenClassification(StableLmPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1372,9 +1356,8 @@ class StableLmForTokenClassification(StableLmPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1382,10 +1365,6 @@ class StableLmForTokenClassification(StableLmPreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 362856fe04..364f9a3d03 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -48,6 +48,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -485,6 +486,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) def forward( self, @@ -496,16 +498,14 @@ class Starcoder2Model(Starcoder2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -574,13 +574,12 @@ class Starcoder2Model(Starcoder2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -770,6 +769,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -784,11 +784,10 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -824,10 +823,9 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -836,12 +834,11 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -850,10 +847,6 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -894,6 +887,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) def forward( self, @@ -906,17 +900,15 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -925,9 +917,8 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -957,10 +948,6 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, @@ -1000,6 +987,7 @@ class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1017,17 +1005,15 @@ class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1036,9 +1022,8 @@ class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) @@ -1046,10 +1031,6 @@ class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel): if labels is not None: loss = self.loss_function(logits, labels, self.config) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 32d64cd167..d6aaa08f2c 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -33,7 +33,7 @@ from ...modeling_outputs import ( ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import add_start_docstrings_to_model_forward, logging +from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging from ..mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, @@ -155,6 +155,7 @@ class Starcoder2Model(MistralModel): self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.embedding_dropout = config.embedding_dropout + @can_return_tuple @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) def forward( self, @@ -166,16 +167,14 @@ class Starcoder2Model(MistralModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -244,13 +243,12 @@ class Starcoder2Model(MistralModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class Starcoder2ForCausalLM(MistralForCausalLM): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index a427cb18f5..0209c85b85 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -45,6 +45,7 @@ from .generic import ( add_model_info_to_custom_pipelines, cached_property, can_return_loss, + can_return_tuple, expand_dims, filter_out_non_signature_kwargs, find_labels, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index f17410a20d..721ecaa37f 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -41,6 +41,11 @@ from .import_utils import ( ) +if is_torch_available(): + # required for @can_return_tuple decorator to work with torchdynamo + import torch # noqa: F401 + + class cached_property(property): """ Descriptor that mimics @property but caches output in member variable. @@ -909,3 +914,62 @@ def is_timm_local_checkpoint(pretrained_model_path: str) -> bool: return is_timm_config_dict(config_dict) return False + + +def set_attribute_for_modules(module: "torch.nn.Module", key: str, value: Any): + """ + Set a value to a module and all submodules. + """ + setattr(module, key, value) + for submodule in module.children(): + set_attribute_for_modules(submodule, key, value) + + +def del_attribute_from_modules(module: "torch.nn.Module", key: str): + """ + Delete a value from a module and all submodules. + """ + # because we might remove it previously in case it's a shared module, e.g. activation function + if hasattr(module, key): + delattr(module, key) + + for submodule in module.children(): + del_attribute_from_modules(submodule, key) + + +def can_return_tuple(func): + """ + Decorator to wrap model method, to call output.to_tuple() if return_dict=False passed as a kwarg or + use_return_dict=False is set in the config. + + Note: + output.to_tuple() convert output to tuple skipping all `None` values. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + is_requested_to_return_tuple = kwargs.pop("return_dict", True) is False + is_configured_to_return_tuple = self.config.use_return_dict is False if hasattr(self, "config") else False + + # The following allows to convert output to tuple ONLY on top level forward call, + # while internal modules of the model will return Output objects + # to be able to use name-based attribute access in modeling code. + + # We will check if we are on top level module, if so, turn off to tuple conversion for all + # underling calls. + is_top_level_module = getattr(self, "_is_top_level_module", True) + if is_configured_to_return_tuple and is_top_level_module: + set_attribute_for_modules(self, "_is_top_level_module", False) + + try: + output = func(self, *args, **kwargs) + if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module): + output = output.to_tuple() + finally: + # Remove the flag after the model forward call is finished. + if is_configured_to_return_tuple and is_top_level_module: + del_attribute_from_modules(self, "_is_top_level_module") + + return output + + return wrapper diff --git a/tests/utils/test_generic.py b/tests/utils/test_generic.py index 3eed30c16e..4af8d7c514 100644 --- a/tests/utils/test_generic.py +++ b/tests/utils/test_generic.py @@ -18,8 +18,11 @@ import warnings import numpy as np +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutput from transformers.testing_utils import require_flax, require_tf, require_torch from transformers.utils import ( + can_return_tuple, expand_dims, filter_out_non_signature_kwargs, flatten_dict, @@ -343,3 +346,119 @@ class ValidationDecoratorTester(unittest.TestCase): with self.assertWarns(UserWarning): kwargs = func3(1, extra_arg=2, extra_arg2=3, extra_arg3=4) self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3}) + + +@require_torch +class CanReturnTupleDecoratorTester(unittest.TestCase): + def _get_model(self, config, store_config=True, raise_in_forward=False): + # Simple model class for testing can_return_tuple decorator. + class SimpleTestModel(torch.nn.Module): + def __init__(self, config): + super().__init__() + if store_config: + self.config = config + + @can_return_tuple + def forward(self, x): + if raise_in_forward: + raise ValueError("Test error") + return BaseModelOutput( + last_hidden_state=x, + hidden_states=None, + attentions=None, + ) + + return SimpleTestModel(config) + + def test_decorator_eager(self): + """Test that the can_return_tuple decorator works with eager mode.""" + + # test nothing is set + config = PretrainedConfig() + model = self._get_model(config) + inputs = torch.tensor(10) + output = model(inputs) + self.assertIsInstance( + output, BaseModelOutput, "output should be a BaseModelOutput when return_dict is not set" + ) + + # test all explicit cases + for config_return_dict in [True, False, None]: + for return_dict in [True, False, None]: + config = PretrainedConfig(return_dict=config_return_dict) + model = self._get_model(config) + output = model(torch.tensor(10), return_dict=return_dict) + + expected_type = tuple if config_return_dict is False or return_dict is False else BaseModelOutput + message = f"output should be a {expected_type.__name__} when config.use_return_dict={config_return_dict} and return_dict={return_dict}" + self.assertIsInstance(output, expected_type, message) + + def test_decorator_compiled(self): + """Test that the can_return_tuple decorator works with compiled mode.""" + config = PretrainedConfig() + + # Output object + model = self._get_model(config) + compiled_model = torch.compile(model) + output = compiled_model(torch.tensor(10)) + self.assertIsInstance(output, BaseModelOutput) + + # Tuple output + model = self._get_model(config) + compiled_model = torch.compile(model) + output = compiled_model(torch.tensor(10), return_dict=False) + self.assertIsInstance(output, tuple) + + def test_decorator_torch_export(self): + """Test that the can_return_tuple decorator works with torch.export.""" + config = PretrainedConfig() + model = self._get_model(config) + torch.export.export(model, args=(torch.tensor(10),)) + + def test_decorator_torchscript(self): + """Test that the can_return_tuple decorator works with torch.jit.trace.""" + config = PretrainedConfig(return_dict=False) + model = self._get_model(config) + inputs = torch.tensor(10) + traced_module = torch.jit.trace(model, inputs) + output = traced_module(inputs) + self.assertIsInstance(output, tuple) + + def test_attribute_cleanup(self): + """Test that the `_is_top_level_module` attribute is removed after the forward call.""" + + config = PretrainedConfig(return_dict=False) + inputs = torch.tensor(10) + + # working case + model = self._get_model(config) + output = model(inputs) + + self.assertIsInstance(output, tuple) + for name, module in model.named_modules(): + self.assertFalse( + hasattr(module, "_is_top_level_module"), + f"Module `{name}` should not have `_is_top_level_module` attribute", + ) + + # model without config + no_config_model = self._get_model(config, store_config=False) + output = no_config_model(inputs) + + self.assertIsInstance(output, BaseModelOutput) + for name, module in no_config_model.named_modules(): + self.assertFalse( + hasattr(module, "_is_top_level_module"), + f"Module `{name}` should not have `_is_top_level_module` attribute", + ) + + # model with raise in forward + model_with_raise = self._get_model(config, raise_in_forward=True) + with self.assertRaises(ValueError): + model_with_raise(inputs) + + for name, module in model_with_raise.named_modules(): + self.assertFalse( + hasattr(module, "_is_top_level_module"), + f"Module `{name}` should not have `_is_top_level_module` attribute", + )