Refactor return_dict logic to remove complicated if/else paths (#36794)
* SAM * CLIP * SigLIP * GOT-OCR2 (depends on SAM) * SigLIP2 (depends on SigLIP) * trigger tests * Fix SAM * Fix missed indexing, use named attributes * Llama * Aria * Bamba * Update llama: missed outputs return type * (fixup) Aria * DiffLlama * Emu3 * Gemma * Gemma2 * Paligemma * Fix paligemma * Gemma3 * GLM * Helium * JetMoe * Jamba * Mistral * Mistral * Mixtral * Nemotron * Olmo * Olmo2 * Persimmon * Phi * Phi3 * PhiMoe * Qwen2 * Qwen2_moe * StableLM * Starcoder2 * Add return_dict decorator * SAM * Update decorator: compile, export, trace-friendly * Llama (decorator) * SAM (decorator) * Add decorator `can_return_tuple` * Llama * Update to decorator * Update CLIP * Update decorator to store `_is_top_level_module` in self * Update decorator to correctly handle compile/export * Remove is_torchdynamo_compiling constraint, all work fine with self attribute assignment * Typing * GPT NeoX * Fixup * Fix attribute Granite * Fix return type mixtral * Update Gemma3 * Fix Cohere and Cohere2 * Fixup * Fix corner case for Phi4, when activation is shared * (fix-copies) deepseekv3, phi4 * Fixup * Apply to qwen3/qwen3_moe * Fix
This commit is contained in:
committed by
GitHub
parent
f304318f5f
commit
a1e389e637
@@ -35,6 +35,7 @@ from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@@ -895,6 +896,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@@ -906,16 +908,14 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
@@ -998,13 +998,12 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
output = BaseModelOutputWithPast(
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
@@ -1182,6 +1181,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
@@ -1196,11 +1196,10 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
@@ -1236,10 +1235,9 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -1248,12 +1246,11 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = outputs.last_hidden_state
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
@@ -1262,10 +1259,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@@ -1445,6 +1438,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
|
||||
@@ -1461,11 +1455,10 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
|
||||
) -> AriaCausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
@@ -1531,7 +1524,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
@@ -1562,7 +1554,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
outputs: CausalLMOutputWithPast = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
@@ -1570,12 +1562,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
logits_to_keep=logits_to_keep,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
logits = outputs.logits
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
@@ -1583,10 +1574,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return AriaCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
|
||||
@@ -33,6 +33,7 @@ from ...image_utils import (
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...modeling_outputs import CausalLMOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils import (
|
||||
@@ -43,6 +44,7 @@ from ...utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@@ -1416,6 +1418,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
|
||||
@@ -1432,11 +1435,10 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
|
||||
) -> AriaCausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
@@ -1502,7 +1504,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
@@ -1533,7 +1534,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
outputs: CausalLMOutputWithPast = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
@@ -1541,12 +1542,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
logits_to_keep=logits_to_keep,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
logits = outputs.logits
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
@@ -1554,10 +1554,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return AriaCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
|
||||
Reference in New Issue
Block a user