Replace error by warning when loading an architecture in another (#11207)
* Replace error by warning when loading an architecture in another * Style * Style again * Add a test * Adapt old test
This commit is contained in:
@@ -231,13 +231,7 @@ class BertGenerationEncoderTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
) = config_and_inputs
|
||||
config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -259,6 +253,11 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_as_bert(self):
|
||||
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
|
||||
config.model_type = "bert"
|
||||
self.model_tester.create_and_check_model(config, input_ids, input_mask, token_labels)
|
||||
|
||||
def test_model_as_decoder(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user