Add missing Maskformer dataclass decorator, add dataclass check in ModelOutput for subclasses (#25638)
* Add @dataclass to MaskFormerPixelDecoderOutput * Add dataclass check if subclass of ModelOutout * Use unittest assertRaises rather than pytest per contribution doc * Update src/transformers/utils/generic.py per suggested change Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -118,6 +118,7 @@ class MaskFormerPixelLevelModuleOutput(ModelOutput):
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskFormerPixelDecoderOutput(ModelOutput):
|
||||
"""
|
||||
MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state
|
||||
|
||||
@@ -20,7 +20,7 @@ import tempfile
|
||||
from collections import OrderedDict, UserDict
|
||||
from collections.abc import MutableMapping
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from dataclasses import fields
|
||||
from dataclasses import fields, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, ContextManager, List, Tuple
|
||||
|
||||
@@ -314,7 +314,26 @@ class ModelOutput(OrderedDict):
|
||||
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Subclasses of ModelOutput must use the @dataclass decorator
|
||||
# This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
|
||||
# issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
|
||||
# Just need to check that the current class is not ModelOutput
|
||||
is_modeloutput_subclass = self.__class__ != ModelOutput
|
||||
|
||||
if is_modeloutput_subclass and not is_dataclass(self):
|
||||
raise TypeError(
|
||||
f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
|
||||
" This is a subclass of ModelOutput and so must use the @dataclass decorator."
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Check the ModelOutput dataclass.
|
||||
|
||||
Only occurs if @dataclass decorator has been used.
|
||||
"""
|
||||
class_fields = fields(self)
|
||||
|
||||
# Safety and consistency checks
|
||||
|
||||
@@ -143,3 +143,23 @@ class ModelOutputTester(unittest.TestCase):
|
||||
|
||||
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||
self.assertEqual(x, unflattened_x)
|
||||
|
||||
|
||||
class ModelOutputTestNoDataclass(ModelOutput):
|
||||
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
||||
|
||||
a: float
|
||||
b: Optional[float] = None
|
||||
c: Optional[float] = None
|
||||
|
||||
|
||||
class ModelOutputSubclassTester(unittest.TestCase):
|
||||
def test_direct_model_output(self):
|
||||
# Check that direct usage of ModelOutput instantiates without errors
|
||||
ModelOutput({"a": 1.1})
|
||||
|
||||
def test_subclass_no_dataclass(self):
|
||||
# Check that a subclass of ModelOutput without @dataclass is invalid
|
||||
# A valid subclass is inherently tested other unit tests above.
|
||||
with self.assertRaises(TypeError):
|
||||
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)
|
||||
|
||||
Reference in New Issue
Block a user