Refactor return_dict logic to remove complicated if/else paths (#36794)

* SAM

* CLIP

* SigLIP

* GOT-OCR2 (depends on SAM)

* SigLIP2 (depends on SigLIP)

* trigger tests

* Fix SAM

* Fix missed indexing, use named attributes

* Llama

* Aria

* Bamba

* Update llama: missed outputs return type

* (fixup) Aria

* DiffLlama

* Emu3

* Gemma

* Gemma2

* Paligemma

* Fix paligemma

* Gemma3

* GLM

* Helium

* JetMoe

* Jamba

* Mistral

* Mistral

* Mixtral

* Nemotron

* Olmo

* Olmo2

* Persimmon

* Phi

* Phi3

* PhiMoe

* Qwen2

* Qwen2_moe

* StableLM

* Starcoder2

* Add return_dict decorator

* SAM

* Update decorator: compile, export, trace - friendly

* Llama (decorator)

* SAM (decorator)

* Add decorator `can_return_tuple`

* Llama

* Update to decorator

* Update CLIP

* Update decorator to store `_is_top_level_module` in self

* Update decorator to correctly handle compile/export

* Remove is_torchdynamo_compiling constraint, all work fine with self attribute assignment

* Typing

* GPT NeoX

* Fixup

* Fix attribute Granite

* Fix return type mixtral

* Update Gemma3

* Fix Cohere amd Cohere2

* Fixup

* Fix corner case for Phi4, when activation is shared

* (fix-copies) deepseekv3, phi4

* Fixup

* Apply to qwen3/qwen3_moe

* Fix
This commit is contained in:
Pavel Iakubovskii
2025-03-31 16:23:37 +01:00
committed by GitHub
parent f304318f5f
commit a1e389e637
62 changed files with 943 additions and 1692 deletions

View File

@@ -18,8 +18,11 @@ import warnings
import numpy as np
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
from transformers.testing_utils import require_flax, require_tf, require_torch
from transformers.utils import (
can_return_tuple,
expand_dims,
filter_out_non_signature_kwargs,
flatten_dict,
@@ -343,3 +346,119 @@ class ValidationDecoratorTester(unittest.TestCase):
with self.assertWarns(UserWarning):
kwargs = func3(1, extra_arg=2, extra_arg2=3, extra_arg3=4)
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
@require_torch
class CanReturnTupleDecoratorTester(unittest.TestCase):
def _get_model(self, config, store_config=True, raise_in_forward=False):
# Simple model class for testing can_return_tuple decorator.
class SimpleTestModel(torch.nn.Module):
def __init__(self, config):
super().__init__()
if store_config:
self.config = config
@can_return_tuple
def forward(self, x):
if raise_in_forward:
raise ValueError("Test error")
return BaseModelOutput(
last_hidden_state=x,
hidden_states=None,
attentions=None,
)
return SimpleTestModel(config)
def test_decorator_eager(self):
"""Test that the can_return_tuple decorator works with eager mode."""
# test nothing is set
config = PretrainedConfig()
model = self._get_model(config)
inputs = torch.tensor(10)
output = model(inputs)
self.assertIsInstance(
output, BaseModelOutput, "output should be a BaseModelOutput when return_dict is not set"
)
# test all explicit cases
for config_return_dict in [True, False, None]:
for return_dict in [True, False, None]:
config = PretrainedConfig(return_dict=config_return_dict)
model = self._get_model(config)
output = model(torch.tensor(10), return_dict=return_dict)
expected_type = tuple if config_return_dict is False or return_dict is False else BaseModelOutput
message = f"output should be a {expected_type.__name__} when config.use_return_dict={config_return_dict} and return_dict={return_dict}"
self.assertIsInstance(output, expected_type, message)
def test_decorator_compiled(self):
"""Test that the can_return_tuple decorator works with compiled mode."""
config = PretrainedConfig()
# Output object
model = self._get_model(config)
compiled_model = torch.compile(model)
output = compiled_model(torch.tensor(10))
self.assertIsInstance(output, BaseModelOutput)
# Tuple output
model = self._get_model(config)
compiled_model = torch.compile(model)
output = compiled_model(torch.tensor(10), return_dict=False)
self.assertIsInstance(output, tuple)
def test_decorator_torch_export(self):
"""Test that the can_return_tuple decorator works with torch.export."""
config = PretrainedConfig()
model = self._get_model(config)
torch.export.export(model, args=(torch.tensor(10),))
def test_decorator_torchscript(self):
"""Test that the can_return_tuple decorator works with torch.jit.trace."""
config = PretrainedConfig(return_dict=False)
model = self._get_model(config)
inputs = torch.tensor(10)
traced_module = torch.jit.trace(model, inputs)
output = traced_module(inputs)
self.assertIsInstance(output, tuple)
def test_attribute_cleanup(self):
"""Test that the `_is_top_level_module` attribute is removed after the forward call."""
config = PretrainedConfig(return_dict=False)
inputs = torch.tensor(10)
# working case
model = self._get_model(config)
output = model(inputs)
self.assertIsInstance(output, tuple)
for name, module in model.named_modules():
self.assertFalse(
hasattr(module, "_is_top_level_module"),
f"Module `{name}` should not have `_is_top_level_module` attribute",
)
# model without config
no_config_model = self._get_model(config, store_config=False)
output = no_config_model(inputs)
self.assertIsInstance(output, BaseModelOutput)
for name, module in no_config_model.named_modules():
self.assertFalse(
hasattr(module, "_is_top_level_module"),
f"Module `{name}` should not have `_is_top_level_module` attribute",
)
# model with raise in forward
model_with_raise = self._get_model(config, raise_in_forward=True)
with self.assertRaises(ValueError):
model_with_raise(inputs)
for name, module in model_with_raise.named_modules():
self.assertFalse(
hasattr(module, "_is_top_level_module"),
f"Module `{name}` should not have `_is_top_level_module` attribute",
)