From 0454e4bd8baa709fe6c44426e9feb3f43baffa0b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 10 Aug 2021 12:20:04 +0200 Subject: [PATCH] Fix ModelOutput instantiation form dictionaries (#13067) * Fix ModelOutput instantiation form dictionaries * Style --- src/transformers/file_utils.py | 12 ++++++++---- tests/test_model_output.py | 6 ++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 1c18fb6070..a395ec275f 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -1850,11 +1850,15 @@ class ModelOutput(OrderedDict): other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) if other_fields_are_none and not is_tensor(first_field): - try: - iterator = iter(first_field) + if isinstance(first_field, dict): + iterator = first_field.items() first_field_iterator = True - except TypeError: - first_field_iterator = False + else: + try: + iterator = iter(first_field) + first_field_iterator = True + except TypeError: + first_field_iterator = False # if we provided an iterator as first field and the iterator is a (key, value) iterator # set the associated fields diff --git a/tests/test_model_output.py b/tests/test_model_output.py index a5160566e6..381f9760a5 100644 --- a/tests/test_model_output.py +++ b/tests/test_model_output.py @@ -101,3 +101,9 @@ class ModelOutputTester(unittest.TestCase): x["a"] = 10 self.assertEqual(x.a, 10) self.assertEqual(x["a"], 10) + + def test_instantiate_from_dict(self): + x = ModelOutputTest({"a": 30, "b": 10}) + self.assertEqual(list(x.keys()), ["a", "b"]) + self.assertEqual(x.a, 30) + self.assertEqual(x.b, 10)