Throw an error if getattribute_from_module can't find anything (#19535)

* return None to avoid recursive call

* Give error

* Give error

* Add test

* More tests

* Quality

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-10-12 20:09:45 +02:00
committed by GitHub
parent 383ad81e68
commit 096838836d
2 changed files with 31 additions and 2 deletions

View File

@@ -555,7 +555,14 @@ def getattribute_from_module(module, attr):
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
# object at the top level. # object at the top level.
transformers_module = importlib.import_module("transformers") transformers_module = importlib.import_module("transformers")
return getattribute_from_module(transformers_module, attr)
if module != transformers_module:
try:
return getattribute_from_module(transformers_module, attr)
except ValueError:
raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
else:
raise ValueError(f"Could not find {attr} in {transformers_module}!")
class _LazyAutoMapping(OrderedDict): class _LazyAutoMapping(OrderedDict):

View File

@@ -17,9 +17,12 @@ import copy
import sys import sys
import tempfile import tempfile
import unittest import unittest
from collections import OrderedDict
from pathlib import Path from pathlib import Path
from transformers import BertConfig, is_torch_available import pytest
from transformers import BertConfig, GPT2Model, is_torch_available
from transformers.models.auto.configuration_auto import CONFIG_MAPPING from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.testing_utils import ( from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER, DUMMY_UNKNOWN_IDENTIFIER,
@@ -372,3 +375,22 @@ class AutoModelTest(unittest.TestCase):
self.assertEqual(counter.get_request_count, 0) self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1) self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0) self.assertEqual(counter.other_request_count, 0)
def test_attr_not_existing(self):
from transformers.models.auto.auto_factory import _LazyAutoMapping
_CONFIG_MAPPING_NAMES = OrderedDict([("bert", "BertConfig")])
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GhostModel")])
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
with pytest.raises(ValueError, match=r"Could not find GhostModel neither in .* nor in .*!"):
_MODEL_MAPPING[BertConfig]
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel")])
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
self.assertEqual(_MODEL_MAPPING[BertConfig], BertModel)
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")])
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model)