Throw an error if getattribute_from_module can't find anything (#19535)
* return None to avoid recursive call * Give error * Give error * Add test * More tests * Quality Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -555,7 +555,14 @@ def getattribute_from_module(module, attr):
|
||||
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
|
||||
# object at the top level.
|
||||
transformers_module = importlib.import_module("transformers")
|
||||
|
||||
if module != transformers_module:
|
||||
try:
|
||||
return getattribute_from_module(transformers_module, attr)
|
||||
except ValueError:
|
||||
raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
|
||||
else:
|
||||
raise ValueError(f"Could not find {attr} in {transformers_module}!")
|
||||
|
||||
|
||||
class _LazyAutoMapping(OrderedDict):
|
||||
|
||||
@@ -17,9 +17,12 @@ import copy
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import BertConfig, is_torch_available
|
||||
import pytest
|
||||
|
||||
from transformers import BertConfig, GPT2Model, is_torch_available
|
||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_UNKNOWN_IDENTIFIER,
|
||||
@@ -372,3 +375,22 @@ class AutoModelTest(unittest.TestCase):
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
def test_attr_not_existing(self):
|
||||
|
||||
from transformers.models.auto.auto_factory import _LazyAutoMapping
|
||||
|
||||
_CONFIG_MAPPING_NAMES = OrderedDict([("bert", "BertConfig")])
|
||||
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GhostModel")])
|
||||
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Could not find GhostModel neither in .* nor in .*!"):
|
||||
_MODEL_MAPPING[BertConfig]
|
||||
|
||||
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel")])
|
||||
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
|
||||
self.assertEqual(_MODEL_MAPPING[BertConfig], BertModel)
|
||||
|
||||
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")])
|
||||
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
|
||||
self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model)
|
||||
|
||||
Reference in New Issue
Block a user