Add option to load a pretrained model with mismatched shapes (#12664)
* Add option to load a pretrained model with mismatched shapes
* Fail at loading when mismatched shapes in Flax
* Fix tests
* Update src/transformers/modeling_flax_utils.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Address review comments
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -25,7 +25,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoModel, is_torch_available, logging
|
||||
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
@@ -1532,6 +1532,35 @@ class ModelTesterMixin:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def test_load_with_mismatched_shapes(self):
    """Check reloading a sequence-classification checkpoint with a different `num_labels`.

    For every model class listed in `MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING`, a freshly
    initialised model is saved to a temporary directory and then reloaded with `num_labels=42`:

    - without `ignore_mismatched_sizes=True` the load is expected to raise `RuntimeError`;
    - with `ignore_mismatched_sizes=True` the load must succeed, log a message containing
      "the shapes did not match", and yield logits whose second dimension is 42.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        # Only sequence-classification heads are exercised by this test.
        if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
            continue

        with self.subTest(msg=f"Testing {model_class}"):
            with tempfile.TemporaryDirectory() as tmp_dir:
                model = model_class(config)
                model.save_pretrained(tmp_dir)

                # Fails when we don't set ignore_mismatched_sizes=True
                with self.assertRaises(RuntimeError):
                    AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)

                # The mismatch is reported through the modeling_utils logger, so capture it.
                logger = logging.get_logger("transformers.modeling_utils")
                with CaptureLogger(logger) as cl:
                    new_model = AutoModelForSequenceClassification.from_pretrained(
                        tmp_dir, num_labels=42, ignore_mismatched_sizes=True
                    )
                self.assertIn("the shapes did not match", cl.out)

                # The reloaded model must be usable and expose the new number of labels.
                new_model.to(torch_device)
                inputs = self._prepare_for_class(inputs_dict, model_class)
                logits = new_model(**inputs).logits
                self.assertEqual(logits.shape[1], 42)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
@@ -24,17 +24,19 @@ import numpy as np
|
||||
import transformers
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import BertConfig, FlaxBertModel, is_flax_available, is_torch_available
|
||||
from transformers import BertConfig, is_flax_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
is_pt_flax_cross_test,
|
||||
is_staging_test,
|
||||
require_flax,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -45,7 +47,13 @@ if is_flax_available():
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.traverse_util import flatten_dict
|
||||
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING
|
||||
from transformers import (
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_MAPPING,
|
||||
FlaxAutoModelForSequenceClassification,
|
||||
FlaxBertModel,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
@@ -516,6 +524,32 @@ class FlaxModelTesterMixin:
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_load_with_mismatched_shapes(self):
    """Check reloading a Flax sequence-classification checkpoint with a different `num_labels`.

    Each model class found in `FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING` is saved to a
    temporary directory and reloaded with `num_labels=42`. The strict load is expected to raise
    `ValueError`; the load with `ignore_mismatched_sizes=True` must log a message containing
    "the shapes did not match" and produce logits whose second dimension is 42.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Restrict the run to sequence-classification heads.
    sequence_classification_classes = get_values(FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING)
    candidates = [cls for cls in self.all_model_classes if cls in sequence_classification_classes]

    for model_class in candidates:
        with self.subTest(msg=f"Testing {model_class}"), tempfile.TemporaryDirectory() as save_dir:
            model_class(config).save_pretrained(save_dir)

            # Fails when we don't set ignore_mismatched_sizes=True
            with self.assertRaises(ValueError):
                FlaxAutoModelForSequenceClassification.from_pretrained(save_dir, num_labels=42)

            # The mismatch is reported through the modeling_flax_utils logger, so capture it.
            flax_logger = logging.get_logger("transformers.modeling_flax_utils")
            with CaptureLogger(flax_logger) as captured:
                reloaded = FlaxAutoModelForSequenceClassification.from_pretrained(
                    save_dir, num_labels=42, ignore_mismatched_sizes=True
                )
            self.assertIn("the shapes did not match", captured.out)

            # The reloaded model must expose the new number of labels.
            outputs = reloaded(**inputs_dict)
            self.assertEqual(outputs["logits"].shape[1], 42)
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
|
||||
@@ -32,6 +32,7 @@ from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
_tf_gpu_memory_limit,
|
||||
is_pt_tf_cross_test,
|
||||
is_staging_test,
|
||||
@@ -40,6 +41,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
tooslow,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
@@ -57,6 +59,7 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
BertConfig,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFBertModel,
|
||||
TFSharedEmbeddings,
|
||||
tf_top_k_top_p_filtering,
|
||||
@@ -1308,6 +1311,34 @@ class TFModelTesterMixin:
|
||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
|
||||
|
||||
def test_load_with_mismatched_shapes(self):
    """Check reloading a TF sequence-classification checkpoint with a different `num_labels`.

    Each model class found in `TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING` is run once, saved
    to a temporary directory and reloaded with `num_labels=42`. The strict load is expected to
    raise `ValueError`; the load with `ignore_mismatched_sizes=True` must log a message containing
    "the shapes did not match" and produce logits whose second dimension is 42.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Restrict the run to sequence-classification heads.
    sequence_classification_classes = get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING)
    candidates = [cls for cls in self.all_model_classes if cls in sequence_classification_classes]

    for model_class in candidates:
        with self.subTest(msg=f"Testing {model_class}"), tempfile.TemporaryDirectory() as save_dir:
            original = model_class(config)
            inputs = self._prepare_for_class(inputs_dict, model_class)
            # Forward pass before saving; presumably needed so the weights exist -- TODO confirm.
            original(**inputs)
            original.save_pretrained(save_dir)

            # Fails when we don't set ignore_mismatched_sizes=True
            with self.assertRaises(ValueError):
                TFAutoModelForSequenceClassification.from_pretrained(save_dir, num_labels=42)

            # The mismatch is reported through the modeling_tf_utils logger, so capture it.
            tf_logger = logging.get_logger("transformers.modeling_tf_utils")
            with CaptureLogger(tf_logger) as captured:
                reloaded = TFAutoModelForSequenceClassification.from_pretrained(
                    save_dir, num_labels=42, ignore_mismatched_sizes=True
                )
            self.assertIn("the shapes did not match", captured.out)

            # The reloaded model must expose the new number of labels.
            outputs = reloaded(**inputs)
            self.assertEqual(outputs.logits.shape[1], 42)
|
||||
|
||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||
# special tokens cannot be bad tokens
|
||||
special_tokens = []
|
||||
|
||||
Reference in New Issue
Block a user