From d7bd325b5a44054341acc536339adab9ef8e8bb2 Mon Sep 17 00:00:00 2001 From: Craig Chan <46288912+rachthree@users.noreply.github.com> Date: Thu, 14 Sep 2023 05:30:49 -0400 Subject: [PATCH] Add missing Maskformer dataclass decorator, add dataclass check in ModelOutput for subclasses (#25638) * Add @dataclass to MaskFormerPixelDecoderOutput * Add dataclass check if subclass of ModelOutout * Use unittest assertRaises rather than pytest per contribution doc * Update src/transformers/utils/generic.py per suggested change Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/maskformer/modeling_maskformer.py | 1 + src/transformers/utils/generic.py | 21 ++++++++++++++++++- tests/utils/test_model_output.py | 20 ++++++++++++++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 69caeeedc0..87b91ed64b 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -118,6 +118,7 @@ class MaskFormerPixelLevelModuleOutput(ModelOutput): decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None +@dataclass class MaskFormerPixelDecoderOutput(ModelOutput): """ MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index b92a237282..4ba379491e 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -20,7 +20,7 @@ import tempfile from collections import OrderedDict, UserDict from collections.abc import MutableMapping from contextlib import ExitStack, contextmanager -from dataclasses import fields +from dataclasses import fields, is_dataclass from enum import Enum from typing import Any, ContextManager, List, Tuple @@ -314,7 +314,26 @@ class ModelOutput(OrderedDict): lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Subclasses of ModelOutput must use the @dataclass decorator + # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__ + # issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed + # Just need to check that the current class is not ModelOutput + is_modeloutput_subclass = self.__class__ != ModelOutput + + if is_modeloutput_subclass and not is_dataclass(self): + raise TypeError( + f"{self.__module__}.{self.__class__.__name__} is not a dataclasss." + " This is a subclass of ModelOutput and so must use the @dataclass decorator." + ) + def __post_init__(self): + """Check the ModelOutput dataclass. + + Only occurs if @dataclass decorator has been used. + """ class_fields = fields(self) # Safety and consistency checks diff --git a/tests/utils/test_model_output.py b/tests/utils/test_model_output.py index b415b6c2ef..abfc5427cf 100644 --- a/tests/utils/test_model_output.py +++ b/tests/utils/test_model_output.py @@ -143,3 +143,23 @@ class ModelOutputTester(unittest.TestCase): unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) self.assertEqual(x, unflattened_x) + + +class ModelOutputTestNoDataclass(ModelOutput): + """Invalid test subclass of ModelOutput where @dataclass decorator is not used""" + + a: float + b: Optional[float] = None + c: Optional[float] = None + + +class ModelOutputSubclassTester(unittest.TestCase): + def test_direct_model_output(self): + # Check that direct usage of ModelOutput instantiates without errors + ModelOutput({"a": 1.1}) + + def test_subclass_no_dataclass(self): + # Check that a subclass of ModelOutput without @dataclass is invalid + # A valid subclass is inherently tested other unit tests above. + with self.assertRaises(TypeError): + ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)