Fix ignore_mismatched_sizes (#14085)

* Fix

* Style

* Name

* Fix tests

* Style

* Remove embed sizes checking

* Disable some tests

* Fix

* Apply suggestion
This commit is contained in:
Li-Huai (Allan) Lin
2021-10-22 00:31:29 +08:00
committed by GitHub
parent e03544a138
commit 234cfefbb0
9 changed files with 64 additions and 6 deletions

View File

@@ -59,6 +59,7 @@ if is_tf_available():
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
TFAutoModel,
TFAutoModelForSequenceClassification,
TFBertModel,
TFSharedEmbeddings,
@@ -104,6 +105,7 @@ class TFModelTesterMixin:
model_tester = None
all_model_classes = ()
all_generative_model_classes = ()
test_mismatched_shapes = True
test_resize_embeddings = True
test_head_masking = True
is_encoder_decoder = False
@@ -1312,6 +1314,8 @@ class TFModelTesterMixin:
self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
def test_load_with_mismatched_shapes(self):
if not self.test_mismatched_shapes:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -1328,6 +1332,8 @@ class TFModelTesterMixin:
# Fails when we don't set ignore_mismatched_sizes=True
with self.assertRaises(ValueError):
new_model = TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
with self.assertRaises(ValueError):
new_model_without_prefix = TFAutoModel.from_pretrained(tmp_dir, vocab_size=10)
logger = logging.get_logger("transformers.modeling_tf_utils")
with CaptureLogger(logger) as cl:
@@ -1339,6 +1345,20 @@ class TFModelTesterMixin:
logits = new_model(**inputs).logits
self.assertEqual(logits.shape[1], 42)
with CaptureLogger(logger) as cl:
new_model_without_prefix = TFAutoModel.from_pretrained(
tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
)
self.assertIn("the shapes did not match", cl.out)
# Although Tf models always have a prefix pointing to `MainLayer`,
# we still add this "without prefix" test to keep a consistency between tf and pt tests.
input_ids = ids_tensor((2, 8), 10)
if self.is_encoder_decoder:
new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
else:
new_model_without_prefix(input_ids)
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []