Add option to load a pretrained model with mismatched shapes (#12664)

* Add option to load a pretrained model with mismatched shapes

* Fail at loading when mismatched shapes in Flax

* Fix tests

* Update src/transformers/modeling_flax_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Address review comments

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-07-13 10:15:15 -04:00
committed by GitHub
parent 7f6d375029
commit 90178b0cef
9 changed files with 228 additions and 67 deletions

View File

@@ -32,6 +32,7 @@ from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
CaptureLogger,
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
@@ -40,6 +41,7 @@ from transformers.testing_utils import (
slow,
tooslow,
)
from transformers.utils import logging
if is_tf_available():
@@ -57,6 +59,7 @@ if is_tf_available():
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
TFAutoModelForSequenceClassification,
TFBertModel,
TFSharedEmbeddings,
tf_top_k_top_p_filtering,
@@ -1308,6 +1311,34 @@ class TFModelTesterMixin:
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
def test_load_with_mismatched_shapes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class not in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
continue
with self.subTest(msg=f"Testing {model_class}"):
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(config)
inputs = self._prepare_for_class(inputs_dict, model_class)
_ = model(**inputs)
model.save_pretrained(tmp_dir)
# Fails when we don't set ignore_mismatched_sizes=True
with self.assertRaises(ValueError):
new_model = TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
logger = logging.get_logger("transformers.modeling_tf_utils")
with CaptureLogger(logger) as cl:
new_model = TFAutoModelForSequenceClassification.from_pretrained(
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
)
self.assertIn("the shapes did not match", cl.out)
logits = new_model(**inputs).logits
self.assertEqual(logits.shape[1], 42)
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []