Add missing Maskformer dataclass decorator, add dataclass check in ModelOutput for subclasses (#25638)
* Add @dataclass to MaskFormerPixelDecoderOutput * Add dataclass check if subclass of ModelOutout * Use unittest assertRaises rather than pytest per contribution doc * Update src/transformers/utils/generic.py per suggested change Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -143,3 +143,23 @@ class ModelOutputTester(unittest.TestCase):
|
||||
|
||||
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||
self.assertEqual(x, unflattened_x)
|
||||
|
||||
|
||||
class ModelOutputTestNoDataclass(ModelOutput):
|
||||
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
||||
|
||||
a: float
|
||||
b: Optional[float] = None
|
||||
c: Optional[float] = None
|
||||
|
||||
|
||||
class ModelOutputSubclassTester(unittest.TestCase):
|
||||
def test_direct_model_output(self):
|
||||
# Check that direct usage of ModelOutput instantiates without errors
|
||||
ModelOutput({"a": 1.1})
|
||||
|
||||
def test_subclass_no_dataclass(self):
|
||||
# Check that a subclass of ModelOutput without @dataclass is invalid
|
||||
# A valid subclass is inherently tested other unit tests above.
|
||||
with self.assertRaises(TypeError):
|
||||
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)
|
||||
|
||||
Reference in New Issue
Block a user