From afce73bd9d891b55dcb8d4d875d17718ffa01ff0 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 23 Nov 2022 15:09:21 -0500 Subject: [PATCH] Fix ModelOutput instantiation when there is only one tuple (#20416) --- src/transformers/utils/generic.py | 10 +++++++++- tests/utils/test_model_output.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 1d9201b95d..b2725b3148 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -227,12 +227,20 @@ class ModelOutput(OrderedDict): # if we provided an iterator as first field and the iterator is a (key, value) iterator # set the associated fields if first_field_iterator: - for element in iterator: + for idx, element in enumerate(iterator): if ( not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str) ): + if idx == 0: + # If we do not have an iterator of key/values, set it as attribute + self[class_fields[0].name] = first_field + else: + # If we have a mixed iterator, raise an error + raise ValueError( + f"Cannot set key/value for {element}. It needs to be a tuple (key, value)." + ) break setattr(self, element[0], element[1]) if element[1] is not None: diff --git a/tests/utils/test_model_output.py b/tests/utils/test_model_output.py index 9fe3e32a99..20ff5ceba8 100644 --- a/tests/utils/test_model_output.py +++ b/tests/utils/test_model_output.py @@ -107,3 +107,16 @@ class ModelOutputTester(unittest.TestCase): self.assertEqual(list(x.keys()), ["a", "b"]) self.assertEqual(x.a, 30) self.assertEqual(x.b, 10) + + def test_instantiate_from_iterator(self): + x = ModelOutputTest([("a", 30), ("b", 10)]) + self.assertEqual(list(x.keys()), ["a", "b"]) + self.assertEqual(x.a, 30) + self.assertEqual(x.b, 10) + + with self.assertRaises(ValueError): + _ = ModelOutputTest([("a", 30), (10, 10)]) + + x = ModelOutputTest(a=(30, 30)) + self.assertEqual(list(x.keys()), ["a"]) + self.assertEqual(x.a, (30, 30))