Throw an error if getattribute_from_module can't find anything (#19535)
* return None to avoid recursive call * Give error * Give error * Add test * More tests * Quality Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -555,7 +555,14 @@ def getattribute_from_module(module, attr):
|
|||||||
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
|
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
|
||||||
# object at the top level.
|
# object at the top level.
|
||||||
transformers_module = importlib.import_module("transformers")
|
transformers_module = importlib.import_module("transformers")
|
||||||
return getattribute_from_module(transformers_module, attr)
|
|
||||||
|
if module != transformers_module:
|
||||||
|
try:
|
||||||
|
return getattribute_from_module(transformers_module, attr)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Could not find {attr} in {transformers_module}!")
|
||||||
|
|
||||||
|
|
||||||
class _LazyAutoMapping(OrderedDict):
|
class _LazyAutoMapping(OrderedDict):
|
||||||
|
|||||||
@@ -17,9 +17,12 @@ import copy
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from transformers import BertConfig, is_torch_available
|
import pytest
|
||||||
|
|
||||||
|
from transformers import BertConfig, GPT2Model, is_torch_available
|
||||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
DUMMY_UNKNOWN_IDENTIFIER,
|
DUMMY_UNKNOWN_IDENTIFIER,
|
||||||
@@ -372,3 +375,22 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
self.assertEqual(counter.get_request_count, 0)
|
self.assertEqual(counter.get_request_count, 0)
|
||||||
self.assertEqual(counter.head_request_count, 1)
|
self.assertEqual(counter.head_request_count, 1)
|
||||||
self.assertEqual(counter.other_request_count, 0)
|
self.assertEqual(counter.other_request_count, 0)
|
||||||
|
|
||||||
|
def test_attr_not_existing(self):
|
||||||
|
|
||||||
|
from transformers.models.auto.auto_factory import _LazyAutoMapping
|
||||||
|
|
||||||
|
_CONFIG_MAPPING_NAMES = OrderedDict([("bert", "BertConfig")])
|
||||||
|
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GhostModel")])
|
||||||
|
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r"Could not find GhostModel neither in .* nor in .*!"):
|
||||||
|
_MODEL_MAPPING[BertConfig]
|
||||||
|
|
||||||
|
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel")])
|
||||||
|
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
|
||||||
|
self.assertEqual(_MODEL_MAPPING[BertConfig], BertModel)
|
||||||
|
|
||||||
|
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")])
|
||||||
|
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
|
||||||
|
self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model)
|
||||||
|
|||||||
Reference in New Issue
Block a user