Refactor return_dict logic to remove complicated if/else paths (#36794)

* SAM

* CLIP

* SigLIP

* GOT-OCR2 (depends on SAM)

* SigLIP2 (depends on SigLIP)

* trigger tests

* Fix SAM

* Fix missed indexing, use named attributes

* Llama

* Aria

* Bamba

* Update llama: missed outputs return type

* (fixup) Aria

* DiffLlama

* Emu3

* Gemma

* Gemma2

* Paligemma

* Fix paligemma

* Gemma3

* GLM

* Helium

* JetMoe

* Jamba

* Mistral

* Mistral

* Mixtral

* Nemotron

* Olmo

* Olmo2

* Persimmon

* Phi

* Phi3

* PhiMoe

* Qwen2

* Qwen2_moe

* StableLM

* Starcoder2

* Add return_dict decorator

* SAM

* Update decorator: compile, export, trace - friendly

* Llama (decorator)

* SAM (decorator)

* Add decorator `can_return_tuple`

* Llama

* Update to decorator

* Update CLIP

* Update decorator to store `_is_top_level_module` in self

* Update decorator to correctly handle compile/export

* Remove is_torchdynamo_compiling constraint, all work fine with self attribute assignment

* Typing

* GPT NeoX

* Fixup

* Fix attribute Granite

* Fix return type mixtral

* Update Gemma3

* Fix Cohere amd Cohere2

* Fixup

* Fix corner case for Phi4, when activation is shared

* (fix-copies) deepseekv3, phi4

* Fixup

* Apply to qwen3/qwen3_moe

* Fix
This commit is contained in:
Pavel Iakubovskii
2025-03-31 16:23:37 +01:00
committed by GitHub
parent f304318f5f
commit a1e389e637
62 changed files with 943 additions and 1692 deletions

View File

@@ -35,6 +35,7 @@ from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
@@ -895,6 +896,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
def forward(
self,
@@ -906,16 +908,14 @@ class AriaTextModel(AriaTextPreTrainedModel):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -998,13 +998,12 @@ class AriaTextModel(AriaTextPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
@@ -1182,6 +1181,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.model
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -1196,11 +1196,10 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1236,10 +1235,9 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -1248,12 +1246,11 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
@@ -1262,10 +1259,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
@@ -1445,6 +1438,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
@@ -1461,11 +1455,10 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
) -> AriaCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1531,7 +1524,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -1562,7 +1554,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs: CausalLMOutputWithPast = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -1570,12 +1562,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
logits_to_keep=logits_to_keep,
cache_position=cache_position,
)
logits = outputs[0]
logits = outputs.logits
loss = None
if labels is not None:
@@ -1583,10 +1574,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return AriaCausalLMOutputWithPast(
loss=loss,
logits=logits,

View File

@@ -33,6 +33,7 @@ from ...image_utils import (
valid_images,
validate_preprocess_arguments,
)
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils import (
@@ -43,6 +44,7 @@ from ...utils import (
TensorType,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
@@ -1416,6 +1418,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
@@ -1432,11 +1435,10 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
) -> AriaCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1502,7 +1504,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -1533,7 +1534,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs: CausalLMOutputWithPast = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -1541,12 +1542,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
logits_to_keep=logits_to_keep,
cache_position=cache_position,
)
logits = outputs[0]
logits = outputs.logits
loss = None
if labels is not None:
@@ -1554,10 +1554,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return AriaCausalLMOutputWithPast(
loss=loss,
logits=logits,