From 096838836d3c7a7d6782d152a4feabd777f2693d Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 12 Oct 2022 20:09:45 +0200 Subject: [PATCH] Throw an error if `getattribute_from_module` can't find anything (#19535) * return None to avoid recursive call * Give error * Give error * Add test * More tests * Quality Co-authored-by: ydshieh --- src/transformers/models/auto/auto_factory.py | 9 +++++++- tests/models/auto/test_modeling_auto.py | 24 +++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 8d3fabda47..04eb3feaac 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -555,7 +555,14 @@ def getattribute_from_module(module, attr): # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the # object at the top level. transformers_module = importlib.import_module("transformers") - return getattribute_from_module(transformers_module, attr) + + if module != transformers_module: + try: + return getattribute_from_module(transformers_module, attr) + except ValueError: + raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!") + else: + raise ValueError(f"Could not find {attr} in {transformers_module}!") class _LazyAutoMapping(OrderedDict): diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 91222c4d00..95df9365c6 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -17,9 +17,12 @@ import copy import sys import tempfile import unittest +from collections import OrderedDict from pathlib import Path -from transformers import BertConfig, is_torch_available +import pytest + +from transformers import BertConfig, GPT2Model, is_torch_available from transformers.models.auto.configuration_auto import CONFIG_MAPPING from transformers.testing_utils import ( DUMMY_UNKNOWN_IDENTIFIER, @@ -372,3 +375,22 @@ class AutoModelTest(unittest.TestCase): self.assertEqual(counter.get_request_count, 0) self.assertEqual(counter.head_request_count, 1) self.assertEqual(counter.other_request_count, 0) + + def test_attr_not_existing(self): + + from transformers.models.auto.auto_factory import _LazyAutoMapping + + _CONFIG_MAPPING_NAMES = OrderedDict([("bert", "BertConfig")]) + _MODEL_MAPPING_NAMES = OrderedDict([("bert", "GhostModel")]) + _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES) + + with pytest.raises(ValueError, match=r"Could not find GhostModel neither in .* nor in .*!"): + _MODEL_MAPPING[BertConfig] + + _MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel")]) + _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES) + self.assertEqual(_MODEL_MAPPING[BertConfig], BertModel) + + _MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")]) + _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES) + self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model)