Replace error by warning when loading an architecture in another (#11207)
* Replace error by warning when loading an architecture in another * Style * Style again * Add a test * Adapt old test
This commit is contained in:
@@ -399,10 +399,11 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||||
if config_dict.get("model_type", False) and hasattr(cls, "model_type"):
|
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||||
assert (
|
logger.warn(
|
||||||
config_dict["model_type"] == cls.model_type
|
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||||
), f"You tried to initiate a model of type '{cls.model_type}' with a pretrained model of type '{config_dict['model_type']}'"
|
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||||
|
)
|
||||||
|
|
||||||
return cls.from_dict(config_dict, **kwargs)
|
return cls.from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -231,13 +231,7 @@ class BertGenerationEncoderTester:
|
|||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs()
|
||||||
(
|
|
||||||
config,
|
|
||||||
input_ids,
|
|
||||||
input_mask,
|
|
||||||
token_labels,
|
|
||||||
) = config_and_inputs
|
|
||||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -259,6 +253,11 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_model_as_bert(self):
|
||||||
|
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
|
||||||
|
config.model_type = "bert"
|
||||||
|
self.model_tester.create_and_check_model(config, input_ids, input_mask, token_labels)
|
||||||
|
|
||||||
def test_model_as_decoder(self):
|
def test_model_as_decoder(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
||||||
|
|||||||
@@ -22,10 +22,10 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available, logging
|
||||||
from transformers.file_utils import WEIGHTS_NAME
|
from transformers.file_utils import WEIGHTS_NAME
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
|
from transformers.testing_utils import CaptureLogger, require_torch, require_torch_multi_gpu, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -1296,6 +1296,7 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
with self.assertRaises(Exception) as context:
|
logger = logging.get_logger("transformers.configuration_utils")
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
BertModel.from_pretrained(TINY_T5)
|
BertModel.from_pretrained(TINY_T5)
|
||||||
self.assertTrue("You tried to initiate a model of type" in str(context.exception))
|
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
|
||||||
|
|||||||
Reference in New Issue
Block a user