Replace error by warning when loading an architecture in another (#11207)

* Replace error by warning when loading an architecture in another

* Style

* Style again

* Add a test

* Adapt old test
This commit is contained in:
Sylvain Gugger
2021-04-13 10:33:52 -04:00
committed by GitHub
parent 22fa0a6004
commit 81009b7a5c
3 changed files with 16 additions and 15 deletions

View File

@@ -231,13 +231,7 @@ class BertGenerationEncoderTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
token_labels,
) = config_and_inputs
config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs()
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@@ -259,6 +253,11 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_as_bert(self):
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
config.model_type = "bert"
self.model_tester.create_and_check_model(config, input_ids, input_mask, token_labels)
def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

View File

@@ -22,10 +22,10 @@ import tempfile
import unittest
from typing import List, Tuple
from transformers import is_torch_available
from transformers import is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
from transformers.testing_utils import CaptureLogger, require_torch, require_torch_multi_gpu, slow, torch_device
if is_torch_available():
@@ -1296,6 +1296,7 @@ class ModelUtilsTest(unittest.TestCase):
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
self.assertIsNotNone(model)
with self.assertRaises(Exception) as context:
logger = logging.get_logger("transformers.configuration_utils")
with CaptureLogger(logger) as cl:
BertModel.from_pretrained(TINY_T5)
self.assertTrue("You tried to initiate a model of type" in str(context.exception))
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)