Fix ignore_mismatched_sizes (#14085)
* Fix
* Style
* Name
* Fix tests
* Style
* Remove embed sizes checking
* Disable some tests
* Fix
* Apply suggestion
This commit is contained in:
committed by GitHub
parent e03544a138
commit 234cfefbb0
@@ -1512,10 +1512,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if ignore_mismatched_sizes:
|
if ignore_mismatched_sizes:
|
||||||
for checkpoint_key in loaded_keys:
|
for checkpoint_key in loaded_keys:
|
||||||
model_key = checkpoint_key
|
model_key = checkpoint_key
|
||||||
if remove_prefix and checkpoint_key.startswith(prefix):
|
if remove_prefix:
|
||||||
model_key = ".".join(checkpoint_key.split(".")[1:])
|
|
||||||
elif add_prefix:
|
|
||||||
model_key = f"{prefix}.{checkpoint_key}"
|
model_key = f"{prefix}.{checkpoint_key}"
|
||||||
|
elif add_prefix:
|
||||||
|
model_key = ".".join(checkpoint_key.split(".")[1:])
|
||||||
|
|
||||||
if (
|
if (
|
||||||
model_key in model_state_dict
|
model_key in model_state_dict
|
||||||
|
|||||||
@@ -220,6 +220,7 @@ class CanineModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
test_mismatched_shapes = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ class ModelTesterMixin:
|
|||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_resize_position_embeddings = False
|
test_resize_position_embeddings = False
|
||||||
test_head_masking = True
|
test_head_masking = True
|
||||||
|
test_mismatched_shapes = True
|
||||||
test_missing_keys = True
|
test_missing_keys = True
|
||||||
test_model_parallel = False
|
test_model_parallel = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
@@ -1638,6 +1639,8 @@ class ModelTesterMixin:
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
def test_load_with_mismatched_shapes(self):
|
def test_load_with_mismatched_shapes(self):
|
||||||
|
if not self.test_mismatched_shapes:
|
||||||
|
return
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -1650,22 +1653,35 @@ class ModelTesterMixin:
|
|||||||
model.save_pretrained(tmp_dir)
|
model.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
# Fails when we don't set ignore_mismatched_sizes=True
|
# Fails when we don't set ignore_mismatched_sizes=True
|
||||||
with self.assertRaises(RuntimeError) as e:
|
with self.assertRaises(RuntimeError):
|
||||||
print(type(e))
|
|
||||||
new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
new_model_without_prefix = AutoModel.from_pretrained(tmp_dir, vocab_size=10)
|
||||||
|
|
||||||
logger = logging.get_logger("transformers.modeling_utils")
|
logger = logging.get_logger("transformers.modeling_utils")
|
||||||
|
|
||||||
with CaptureLogger(logger) as cl:
|
with CaptureLogger(logger) as cl:
|
||||||
new_model = AutoModelForSequenceClassification.from_pretrained(
|
new_model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
|
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
|
||||||
)
|
)
|
||||||
self.assertIn("the shapes did not match", cl.out)
|
self.assertIn("the shapes did not match", cl.out)
|
||||||
|
|
||||||
new_model.to(torch_device)
|
new_model.to(torch_device)
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
logits = new_model(**inputs).logits
|
logits = new_model(**inputs).logits
|
||||||
self.assertEqual(logits.shape[1], 42)
|
self.assertEqual(logits.shape[1], 42)
|
||||||
|
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
new_model_without_prefix = AutoModel.from_pretrained(
|
||||||
|
tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
|
||||||
|
)
|
||||||
|
self.assertIn("the shapes did not match", cl.out)
|
||||||
|
input_ids = ids_tensor((2, 8), 10)
|
||||||
|
new_model_without_prefix.to(torch_device)
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
|
||||||
|
else:
|
||||||
|
new_model_without_prefix(input_ids)
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
@@ -149,6 +149,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
test_attn_probs = False
|
test_attn_probs = False
|
||||||
|
test_mismatched_shapes = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = FlaxBigBirdModelTester(self)
|
self.model_tester = FlaxBigBirdModelTester(self)
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ if is_flax_available():
|
|||||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
FLAX_MODEL_MAPPING,
|
FLAX_MODEL_MAPPING,
|
||||||
|
FlaxAutoModel,
|
||||||
FlaxAutoModelForSequenceClassification,
|
FlaxAutoModelForSequenceClassification,
|
||||||
FlaxBertModel,
|
FlaxBertModel,
|
||||||
)
|
)
|
||||||
@@ -116,6 +117,7 @@ def random_attention_mask(shape, rng=None):
|
|||||||
class FlaxModelTesterMixin:
|
class FlaxModelTesterMixin:
|
||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
|
test_mismatched_shapes = True
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class):
|
def _prepare_for_class(self, inputs_dict, model_class):
|
||||||
@@ -579,6 +581,8 @@ class FlaxModelTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_load_with_mismatched_shapes(self):
|
def test_load_with_mismatched_shapes(self):
|
||||||
|
if not self.test_mismatched_shapes:
|
||||||
|
return
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -593,6 +597,8 @@ class FlaxModelTesterMixin:
|
|||||||
# Fails when we don't set ignore_mismatched_sizes=True
|
# Fails when we don't set ignore_mismatched_sizes=True
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
new_model = FlaxAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
new_model = FlaxAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
new_model_without_prefix = FlaxAutoModel.from_pretrained(tmp_dir, vocab_size=10)
|
||||||
|
|
||||||
logger = logging.get_logger("transformers.modeling_flax_utils")
|
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||||
with CaptureLogger(logger) as cl:
|
with CaptureLogger(logger) as cl:
|
||||||
@@ -604,6 +610,17 @@ class FlaxModelTesterMixin:
|
|||||||
logits = new_model(**inputs_dict)["logits"]
|
logits = new_model(**inputs_dict)["logits"]
|
||||||
self.assertEqual(logits.shape[1], 42)
|
self.assertEqual(logits.shape[1], 42)
|
||||||
|
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
new_model_without_prefix = FlaxAutoModel.from_pretrained(
|
||||||
|
tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
|
||||||
|
)
|
||||||
|
self.assertIn("the shapes did not match", cl.out)
|
||||||
|
input_ids = ids_tensor((2, 8), 10)
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
|
||||||
|
else:
|
||||||
|
new_model_without_prefix(input_ids)
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
@@ -260,6 +260,7 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
test_mismatched_shapes = False
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
|
TFAutoModel,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
@@ -104,6 +105,7 @@ class TFModelTesterMixin:
|
|||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
all_generative_model_classes = ()
|
all_generative_model_classes = ()
|
||||||
|
test_mismatched_shapes = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_head_masking = True
|
test_head_masking = True
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
@@ -1312,6 +1314,8 @@ class TFModelTesterMixin:
|
|||||||
self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
|
self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
|
||||||
|
|
||||||
def test_load_with_mismatched_shapes(self):
|
def test_load_with_mismatched_shapes(self):
|
||||||
|
if not self.test_mismatched_shapes:
|
||||||
|
return
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -1328,6 +1332,8 @@ class TFModelTesterMixin:
|
|||||||
# Fails when we don't set ignore_mismatched_sizes=True
|
# Fails when we don't set ignore_mismatched_sizes=True
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
new_model = TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
new_model = TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
new_model_without_prefix = TFAutoModel.from_pretrained(tmp_dir, vocab_size=10)
|
||||||
|
|
||||||
logger = logging.get_logger("transformers.modeling_tf_utils")
|
logger = logging.get_logger("transformers.modeling_tf_utils")
|
||||||
with CaptureLogger(logger) as cl:
|
with CaptureLogger(logger) as cl:
|
||||||
@@ -1339,6 +1345,20 @@ class TFModelTesterMixin:
|
|||||||
logits = new_model(**inputs).logits
|
logits = new_model(**inputs).logits
|
||||||
self.assertEqual(logits.shape[1], 42)
|
self.assertEqual(logits.shape[1], 42)
|
||||||
|
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
new_model_without_prefix = TFAutoModel.from_pretrained(
|
||||||
|
tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
|
||||||
|
)
|
||||||
|
self.assertIn("the shapes did not match", cl.out)
|
||||||
|
|
||||||
|
# Although Tf models always have a prefix pointing to `MainLayer`,
|
||||||
|
# we still add this "without prefix" test to keep a consistency between tf and pt tests.
|
||||||
|
input_ids = ids_tensor((2, 8), 10)
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
|
||||||
|
else:
|
||||||
|
new_model_without_prefix(input_ids)
|
||||||
|
|
||||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||||
# special tokens cannot be bad tokens
|
# special tokens cannot be bad tokens
|
||||||
special_tokens = []
|
special_tokens = []
|
||||||
|
|||||||
@@ -165,6 +165,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_onnx = False
|
test_onnx = False
|
||||||
|
test_mismatched_shapes = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFTransfoXLModelTester(self)
|
self.model_tester = TFTransfoXLModelTester(self)
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
|
test_mismatched_shapes = False
|
||||||
|
|
||||||
def check_cutoffs_and_n_token(
|
def check_cutoffs_and_n_token(
|
||||||
self, copied_cutoffs, layer, model_embed, model, model_class, resized_value, vocab_size
|
self, copied_cutoffs, layer, model_embed, model, model_class, resized_value, vocab_size
|
||||||
|
|||||||
Reference in New Issue
Block a user