🚨🚨🚨 Enforce single model initialization (#21431)
* Enforce single model initialization * Add OneFormer example for problem 3 * Do it the Stas way * Actually rename the uses... * Rewrite test * Try to change the test this way * Fix all init slow/fast tests * Break connection * Fix more tests * Fix test for initialization * Remove custom test * Quality * Fix last failing tests * The end?
This commit is contained in:
@@ -15,9 +15,6 @@
|
||||
""" Testing suite for the PyTorch LayoutLMv2 model. """
|
||||
|
||||
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_detectron2, require_torch, require_torch_multi_gpu, slow, torch_device
|
||||
@@ -31,7 +28,6 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MODEL_MAPPING,
|
||||
LayoutLMv2Config,
|
||||
LayoutLMv2ForQuestionAnswering,
|
||||
LayoutLMv2ForSequenceClassification,
|
||||
@@ -312,54 +308,6 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
base_class = base_class[0]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
# make a copy of model class to not break future tests
|
||||
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
|
||||
class CopyClass(model_class):
|
||||
pass
|
||||
|
||||
model_class_copy = CopyClass
|
||||
|
||||
# make sure that all keys are expected for test
|
||||
model_class_copy._keys_to_ignore_on_load_missing = []
|
||||
|
||||
# make init deterministic, but make sure that
|
||||
# non-initialized weights throw errors nevertheless
|
||||
model_class_copy._init_weights = self._mock_init_weights
|
||||
|
||||
model = base_class(config)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
# this will often delete a single weight of a multi-weight module
|
||||
# to test an edge case
|
||||
random_key_to_del = random.choice(list(state_dict.keys()))
|
||||
del state_dict[random_key_to_del]
|
||||
|
||||
# check that certain keys didn't get saved with the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
|
||||
|
||||
model_fast_init = model_class_copy.from_pretrained(tmpdirname)
|
||||
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
if key == "layoutlmv2.visual_segment_embedding":
|
||||
# we skip the visual segment embedding as it has a custom initialization scheme
|
||||
continue
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
Reference in New Issue
Block a user