Refactor return_dict logic to remove complicated if/else paths (#36794)
* SAM * CLIP * SigLIP * GOT-OCR2 (depends on SAM) * SigLIP2 (depends on SigLIP) * trigger tests * Fix SAM * Fix missed indexing, use named attributes * Llama * Aria * Bamba * Update llama: missed outputs return type * (fixup) Aria * DiffLlama * Emu3 * Gemma * Gemma2 * Paligemma * Fix paligemma * Gemma3 * GLM * Helium * JetMoe * Jamba * Mistral * Mistral * Mixtral * Nemotron * Olmo * Olmo2 * Persimmon * Phi * Phi3 * PhiMoe * Qwen2 * Qwen2_moe * StableLM * Starcoder2 * Add return_dict decorator * SAM * Update decorator: compile, export, trace - friendly * Llama (decorator) * SAM (decorator) * Add decorator `can_return_tuple` * Llama * Update to decorator * Update CLIP * Update decorator to store `_is_top_level_module` in self * Update decorator to correctly handle compile/export * Remove is_torchdynamo_compiling constraint, all work fine with self attribute assignment * Typing * GPT NeoX * Fixup * Fix attribute Granite * Fix return type mixtral * Update Gemma3 * Fix Cohere amd Cohere2 * Fixup * Fix corner case for Phi4, when activation is shared * (fix-copies) deepseekv3, phi4 * Fixup * Apply to qwen3/qwen3_moe * Fix
This commit is contained in:
committed by
GitHub
parent
f304318f5f
commit
a1e389e637
@@ -18,8 +18,11 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
from transformers.testing_utils import require_flax, require_tf, require_torch
|
||||
from transformers.utils import (
|
||||
can_return_tuple,
|
||||
expand_dims,
|
||||
filter_out_non_signature_kwargs,
|
||||
flatten_dict,
|
||||
@@ -343,3 +346,119 @@ class ValidationDecoratorTester(unittest.TestCase):
|
||||
with self.assertWarns(UserWarning):
|
||||
kwargs = func3(1, extra_arg=2, extra_arg2=3, extra_arg3=4)
|
||||
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
|
||||
|
||||
|
||||
@require_torch
|
||||
class CanReturnTupleDecoratorTester(unittest.TestCase):
|
||||
def _get_model(self, config, store_config=True, raise_in_forward=False):
|
||||
# Simple model class for testing can_return_tuple decorator.
|
||||
class SimpleTestModel(torch.nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if store_config:
|
||||
self.config = config
|
||||
|
||||
@can_return_tuple
|
||||
def forward(self, x):
|
||||
if raise_in_forward:
|
||||
raise ValueError("Test error")
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=x,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
return SimpleTestModel(config)
|
||||
|
||||
def test_decorator_eager(self):
|
||||
"""Test that the can_return_tuple decorator works with eager mode."""
|
||||
|
||||
# test nothing is set
|
||||
config = PretrainedConfig()
|
||||
model = self._get_model(config)
|
||||
inputs = torch.tensor(10)
|
||||
output = model(inputs)
|
||||
self.assertIsInstance(
|
||||
output, BaseModelOutput, "output should be a BaseModelOutput when return_dict is not set"
|
||||
)
|
||||
|
||||
# test all explicit cases
|
||||
for config_return_dict in [True, False, None]:
|
||||
for return_dict in [True, False, None]:
|
||||
config = PretrainedConfig(return_dict=config_return_dict)
|
||||
model = self._get_model(config)
|
||||
output = model(torch.tensor(10), return_dict=return_dict)
|
||||
|
||||
expected_type = tuple if config_return_dict is False or return_dict is False else BaseModelOutput
|
||||
message = f"output should be a {expected_type.__name__} when config.use_return_dict={config_return_dict} and return_dict={return_dict}"
|
||||
self.assertIsInstance(output, expected_type, message)
|
||||
|
||||
def test_decorator_compiled(self):
|
||||
"""Test that the can_return_tuple decorator works with compiled mode."""
|
||||
config = PretrainedConfig()
|
||||
|
||||
# Output object
|
||||
model = self._get_model(config)
|
||||
compiled_model = torch.compile(model)
|
||||
output = compiled_model(torch.tensor(10))
|
||||
self.assertIsInstance(output, BaseModelOutput)
|
||||
|
||||
# Tuple output
|
||||
model = self._get_model(config)
|
||||
compiled_model = torch.compile(model)
|
||||
output = compiled_model(torch.tensor(10), return_dict=False)
|
||||
self.assertIsInstance(output, tuple)
|
||||
|
||||
def test_decorator_torch_export(self):
|
||||
"""Test that the can_return_tuple decorator works with torch.export."""
|
||||
config = PretrainedConfig()
|
||||
model = self._get_model(config)
|
||||
torch.export.export(model, args=(torch.tensor(10),))
|
||||
|
||||
def test_decorator_torchscript(self):
|
||||
"""Test that the can_return_tuple decorator works with torch.jit.trace."""
|
||||
config = PretrainedConfig(return_dict=False)
|
||||
model = self._get_model(config)
|
||||
inputs = torch.tensor(10)
|
||||
traced_module = torch.jit.trace(model, inputs)
|
||||
output = traced_module(inputs)
|
||||
self.assertIsInstance(output, tuple)
|
||||
|
||||
def test_attribute_cleanup(self):
|
||||
"""Test that the `_is_top_level_module` attribute is removed after the forward call."""
|
||||
|
||||
config = PretrainedConfig(return_dict=False)
|
||||
inputs = torch.tensor(10)
|
||||
|
||||
# working case
|
||||
model = self._get_model(config)
|
||||
output = model(inputs)
|
||||
|
||||
self.assertIsInstance(output, tuple)
|
||||
for name, module in model.named_modules():
|
||||
self.assertFalse(
|
||||
hasattr(module, "_is_top_level_module"),
|
||||
f"Module `{name}` should not have `_is_top_level_module` attribute",
|
||||
)
|
||||
|
||||
# model without config
|
||||
no_config_model = self._get_model(config, store_config=False)
|
||||
output = no_config_model(inputs)
|
||||
|
||||
self.assertIsInstance(output, BaseModelOutput)
|
||||
for name, module in no_config_model.named_modules():
|
||||
self.assertFalse(
|
||||
hasattr(module, "_is_top_level_module"),
|
||||
f"Module `{name}` should not have `_is_top_level_module` attribute",
|
||||
)
|
||||
|
||||
# model with raise in forward
|
||||
model_with_raise = self._get_model(config, raise_in_forward=True)
|
||||
with self.assertRaises(ValueError):
|
||||
model_with_raise(inputs)
|
||||
|
||||
for name, module in model_with_raise.named_modules():
|
||||
self.assertFalse(
|
||||
hasattr(module, "_is_top_level_module"),
|
||||
f"Module `{name}` should not have `_is_top_level_module` attribute",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user