Add missing Maskformer dataclass decorator, add dataclass check in ModelOutput for subclasses (#25638)

* Add @dataclass to MaskFormerPixelDecoderOutput

* Add dataclass check if subclass of ModelOutout

* Use unittest assertRaises rather than pytest per contribution doc

* Update src/transformers/utils/generic.py per suggested change

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Craig Chan
2023-09-14 05:30:49 -04:00
committed by GitHub
parent 05de038f3d
commit d7bd325b5a
3 changed files with 41 additions and 1 deletions

View File

@@ -118,6 +118,7 @@ class MaskFormerPixelLevelModuleOutput(ModelOutput):
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerPixelDecoderOutput(ModelOutput):
"""
MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state

View File

@@ -20,7 +20,7 @@ import tempfile
from collections import OrderedDict, UserDict
from collections.abc import MutableMapping
from contextlib import ExitStack, contextmanager
from dataclasses import fields
from dataclasses import fields, is_dataclass
from enum import Enum
from typing import Any, ContextManager, List, Tuple
@@ -314,7 +314,26 @@ class ModelOutput(OrderedDict):
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Subclasses of ModelOutput must use the @dataclass decorator
# This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
# issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
# Just need to check that the current class is not ModelOutput
is_modeloutput_subclass = self.__class__ != ModelOutput
if is_modeloutput_subclass and not is_dataclass(self):
raise TypeError(
f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
" This is a subclass of ModelOutput and so must use the @dataclass decorator."
)
def __post_init__(self):
"""Check the ModelOutput dataclass.
Only occurs if @dataclass decorator has been used.
"""
class_fields = fields(self)
# Safety and consistency checks

View File

@@ -143,3 +143,23 @@ class ModelOutputTester(unittest.TestCase):
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)
class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
a: float
b: Optional[float] = None
c: Optional[float] = None
class ModelOutputSubclassTester(unittest.TestCase):
def test_direct_model_output(self):
# Check that direct usage of ModelOutput instantiates without errors
ModelOutput({"a": 1.1})
def test_subclass_no_dataclass(self):
# Check that a subclass of ModelOutput without @dataclass is invalid
# A valid subclass is inherently tested other unit tests above.
with self.assertRaises(TypeError):
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)