Add missing Maskformer dataclass decorator, add dataclass check in ModelOutput for subclasses (#25638)

* Add @dataclass to MaskFormerPixelDecoderOutput

* Add dataclass check if subclass of ModelOutout

* Use unittest assertRaises rather than pytest per contribution doc

* Update src/transformers/utils/generic.py per suggested change

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Craig Chan
2023-09-14 05:30:49 -04:00
committed by GitHub
parent 05de038f3d
commit d7bd325b5a
3 changed files with 41 additions and 1 deletions

View File

@@ -143,3 +143,23 @@ class ModelOutputTester(unittest.TestCase):
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)
class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
a: float
b: Optional[float] = None
c: Optional[float] = None
class ModelOutputSubclassTester(unittest.TestCase):
def test_direct_model_output(self):
# Check that direct usage of ModelOutput instantiates without errors
ModelOutput({"a": 1.1})
def test_subclass_no_dataclass(self):
# Check that a subclass of ModelOutput without @dataclass is invalid
# A valid subclass is inherently tested other unit tests above.
with self.assertRaises(TypeError):
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)