Fix ignore_mismatched_sizes (#14085)
* Fix * Style * Name * Fix tests * Style * Remove embed sizes checking * Disable some tests * Fix * Apply suggestion
This commit is contained in:
committed by
GitHub
parent
e03544a138
commit
234cfefbb0
@@ -98,6 +98,7 @@ class ModelTesterMixin:
|
||||
test_resize_embeddings = True
|
||||
test_resize_position_embeddings = False
|
||||
test_head_masking = True
|
||||
test_mismatched_shapes = True
|
||||
test_missing_keys = True
|
||||
test_model_parallel = False
|
||||
is_encoder_decoder = False
|
||||
@@ -1638,6 +1639,8 @@ class ModelTesterMixin:
|
||||
loss.backward()
|
||||
|
||||
def test_load_with_mismatched_shapes(self):
|
||||
if not self.test_mismatched_shapes:
|
||||
return
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -1650,22 +1653,35 @@ class ModelTesterMixin:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
# Fails when we don't set ignore_mismatched_sizes=True
|
||||
with self.assertRaises(RuntimeError) as e:
|
||||
print(type(e))
|
||||
with self.assertRaises(RuntimeError):
|
||||
new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
||||
with self.assertRaises(RuntimeError):
|
||||
new_model_without_prefix = AutoModel.from_pretrained(tmp_dir, vocab_size=10)
|
||||
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
new_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
|
||||
)
|
||||
self.assertIn("the shapes did not match", cl.out)
|
||||
|
||||
new_model.to(torch_device)
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
logits = new_model(**inputs).logits
|
||||
self.assertEqual(logits.shape[1], 42)
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
new_model_without_prefix = AutoModel.from_pretrained(
|
||||
tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
|
||||
)
|
||||
self.assertIn("the shapes did not match", cl.out)
|
||||
input_ids = ids_tensor((2, 8), 10)
|
||||
new_model_without_prefix.to(torch_device)
|
||||
if self.is_encoder_decoder:
|
||||
new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
|
||||
else:
|
||||
new_model_without_prefix(input_ids)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user