Add missing Maskformer dataclass decorator, add dataclass check in ModelOutput for subclasses (#25638)
* Add @dataclass to MaskFormerPixelDecoderOutput * Add dataclass check if subclass of ModelOutout * Use unittest assertRaises rather than pytest per contribution doc * Update src/transformers/utils/generic.py per suggested change Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -118,6 +118,7 @@ class MaskFormerPixelLevelModuleOutput(ModelOutput):
|
|||||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class MaskFormerPixelDecoderOutput(ModelOutput):
|
class MaskFormerPixelDecoderOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state
|
MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import tempfile
|
|||||||
from collections import OrderedDict, UserDict
|
from collections import OrderedDict, UserDict
|
||||||
from collections.abc import MutableMapping
|
from collections.abc import MutableMapping
|
||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from dataclasses import fields
|
from dataclasses import fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, ContextManager, List, Tuple
|
from typing import Any, ContextManager, List, Tuple
|
||||||
|
|
||||||
@@ -314,7 +314,26 @@ class ModelOutput(OrderedDict):
|
|||||||
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
|
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Subclasses of ModelOutput must use the @dataclass decorator
|
||||||
|
# This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
|
||||||
|
# issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
|
||||||
|
# Just need to check that the current class is not ModelOutput
|
||||||
|
is_modeloutput_subclass = self.__class__ != ModelOutput
|
||||||
|
|
||||||
|
if is_modeloutput_subclass and not is_dataclass(self):
|
||||||
|
raise TypeError(
|
||||||
|
f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
|
||||||
|
" This is a subclass of ModelOutput and so must use the @dataclass decorator."
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
"""Check the ModelOutput dataclass.
|
||||||
|
|
||||||
|
Only occurs if @dataclass decorator has been used.
|
||||||
|
"""
|
||||||
class_fields = fields(self)
|
class_fields = fields(self)
|
||||||
|
|
||||||
# Safety and consistency checks
|
# Safety and consistency checks
|
||||||
|
|||||||
@@ -143,3 +143,23 @@ class ModelOutputTester(unittest.TestCase):
|
|||||||
|
|
||||||
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||||
self.assertEqual(x, unflattened_x)
|
self.assertEqual(x, unflattened_x)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOutputTestNoDataclass(ModelOutput):
|
||||||
|
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
||||||
|
|
||||||
|
a: float
|
||||||
|
b: Optional[float] = None
|
||||||
|
c: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOutputSubclassTester(unittest.TestCase):
|
||||||
|
def test_direct_model_output(self):
|
||||||
|
# Check that direct usage of ModelOutput instantiates without errors
|
||||||
|
ModelOutput({"a": 1.1})
|
||||||
|
|
||||||
|
def test_subclass_no_dataclass(self):
|
||||||
|
# Check that a subclass of ModelOutput without @dataclass is invalid
|
||||||
|
# A valid subclass is inherently tested other unit tests above.
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)
|
||||||
|
|||||||
Reference in New Issue
Block a user