Throw an error if getattribute_from_module can't find anything (#19535)
* return None to avoid recursive call * Give error * Give error * Add test * More tests * Quality Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -17,9 +17,12 @@ import copy
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import BertConfig, is_torch_available
|
||||
import pytest
|
||||
|
||||
from transformers import BertConfig, GPT2Model, is_torch_available
|
||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_UNKNOWN_IDENTIFIER,
|
||||
@@ -372,3 +375,22 @@ class AutoModelTest(unittest.TestCase):
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
def test_attr_not_existing(self):
|
||||
|
||||
from transformers.models.auto.auto_factory import _LazyAutoMapping
|
||||
|
||||
_CONFIG_MAPPING_NAMES = OrderedDict([("bert", "BertConfig")])
|
||||
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GhostModel")])
|
||||
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Could not find GhostModel neither in .* nor in .*!"):
|
||||
_MODEL_MAPPING[BertConfig]
|
||||
|
||||
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel")])
|
||||
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
|
||||
self.assertEqual(_MODEL_MAPPING[BertConfig], BertModel)
|
||||
|
||||
_MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")])
|
||||
_MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
|
||||
self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model)
|
||||
|
||||
Reference in New Issue
Block a user