From ae736163d0d7a3a167ff0df3bf6c824437bbba2a Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 11 Sep 2020 12:01:33 -0400 Subject: [PATCH] Add tests and fix various bugs in ModelOutput (#7073) * Add tests and fix various bugs in ModelOutput * Update tests/test_model_output.py Co-authored-by: Patrick von Platen Co-authored-by: Patrick von Platen --- src/transformers/file_utils.py | 14 +++++ tests/test_model_output.py | 103 +++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 tests/test_model_output.py diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index beef7e833b..07ce03be1b 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -988,6 +988,8 @@ class ModelOutput(OrderedDict): setattr(self, element[0], element[1]) if element[1] is not None: self[element[0]] = element[1] + elif first_field is not None: + self[class_fields[0].name] = first_field else: for field in class_fields: v = getattr(self, field.name) @@ -1013,6 +1015,18 @@ class ModelOutput(OrderedDict): else: return self.to_tuple()[k] + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + def to_tuple(self) -> Tuple[Any]: """ Convert self to a tuple containing all the attributes/keys that are not ``None``. diff --git a/tests/test_model_output.py b/tests/test_model_output.py new file mode 100644 index 0000000000..a5160566e6 --- /dev/null +++ b/tests/test_model_output.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# Copyright 2020 The Hugging Face Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from dataclasses import dataclass +from typing import Optional + +from transformers.file_utils import ModelOutput + + +@dataclass +class ModelOutputTest(ModelOutput): + a: float + b: Optional[float] = None + c: Optional[float] = None + + +class ModelOutputTester(unittest.TestCase): + def test_get_attributes(self): + x = ModelOutputTest(a=30) + self.assertEqual(x.a, 30) + self.assertIsNone(x.b) + self.assertIsNone(x.c) + with self.assertRaises(AttributeError): + _ = x.d + + def test_index_with_ints_and_slices(self): + x = ModelOutputTest(a=30, b=10) + self.assertEqual(x[0], 30) + self.assertEqual(x[1], 10) + self.assertEqual(x[:2], (30, 10)) + self.assertEqual(x[:], (30, 10)) + + x = ModelOutputTest(a=30, c=10) + self.assertEqual(x[0], 30) + self.assertEqual(x[1], 10) + self.assertEqual(x[:2], (30, 10)) + self.assertEqual(x[:], (30, 10)) + + def test_index_with_strings(self): + x = ModelOutputTest(a=30, b=10) + self.assertEqual(x["a"], 30) + self.assertEqual(x["b"], 10) + with self.assertRaises(KeyError): + _ = x["c"] + + x = ModelOutputTest(a=30, c=10) + self.assertEqual(x["a"], 30) + self.assertEqual(x["c"], 10) + with self.assertRaises(KeyError): + _ = x["b"] + + def test_dict_like_properties(self): + x = ModelOutputTest(a=30) + self.assertEqual(list(x.keys()), ["a"]) + self.assertEqual(list(x.values()), [30]) + self.assertEqual(list(x.items()), [("a", 30)]) + self.assertEqual(list(x), ["a"]) + + x = ModelOutputTest(a=30, b=10) + self.assertEqual(list(x.keys()), ["a", "b"]) + self.assertEqual(list(x.values()), [30, 10]) + self.assertEqual(list(x.items()), [("a", 30), ("b", 10)]) + self.assertEqual(list(x), ["a", "b"]) + + x = ModelOutputTest(a=30, c=10) + self.assertEqual(list(x.keys()), ["a", "c"]) + self.assertEqual(list(x.values()), [30, 10]) + self.assertEqual(list(x.items()), [("a", 30), ("c", 10)]) + self.assertEqual(list(x), ["a", "c"]) + + with self.assertRaises(Exception): + x = x.update({"d": 20}) + with self.assertRaises(Exception): + del x["a"] + with self.assertRaises(Exception): + _ = x.pop("a") + with self.assertRaises(Exception): + _ = x.setdefault("d", 32) + + def test_set_attributes(self): + x = ModelOutputTest(a=30) + x.a = 10 + self.assertEqual(x.a, 10) + self.assertEqual(x["a"], 10) + + def test_set_keys(self): + x = ModelOutputTest(a=30) + x["a"] = 10 + self.assertEqual(x.a, 10) + self.assertEqual(x["a"], 10)