🚨🚨🚨 Enforce single model initialization (#21431)

* Enforce single model initialization

* Add OneFormer example for problem 3

* Do it the Stas way

* Actually rename the uses...

* Rewrite test

* Try to change the test this way

* Fix all init slow/fast tests

* Break connection

* Fix more tests

* Fix test for initialization

* Remove custom test

* Quality

* Fix last failing tests

* The end?
This commit is contained in:
Sylvain Gugger
2023-02-09 15:46:26 -05:00
committed by GitHub
parent 2020ac4bd6
commit 04b2f13c37
25 changed files with 277 additions and 123 deletions

View File

@@ -15,9 +15,6 @@
""" Testing suite for the PyTorch LayoutLMv2 model. """
import os
import random
import tempfile
import unittest
from transformers.testing_utils import require_detectron2, require_torch, require_torch_multi_gpu, slow, torch_device
@@ -31,7 +28,6 @@ if is_torch_available():
import torch
from transformers import (
MODEL_MAPPING,
LayoutLMv2Config,
LayoutLMv2ForQuestionAnswering,
LayoutLMv2ForSequenceClassification,
@@ -312,54 +308,6 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_save_load_fast_init_from_base(self):
    """Check that fast and slow weight initialization agree when loading a base checkpoint.

    For every head model class (every class other than the base class mapped
    from the config), a base-model checkpoint with one randomly removed weight
    is saved to a temporary directory and then loaded twice through
    ``from_pretrained``: once with the default settings and once with
    ``_fast_init=False``. Every resulting parameter, except the custom-initialized
    ``layoutlmv2.visual_segment_embedding``, must match between the two loads.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = MODEL_MAPPING[config.__class__]

    if isinstance(base_class, tuple):
        base_class = base_class[0]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        # make a copy of model class to not break future tests
        # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
        class CopyClass(model_class):
            pass

        model_class_copy = CopyClass

        # make sure that all keys are expected for test
        model_class_copy._keys_to_ignore_on_load_missing = []

        # make init deterministic, but make sure that
        # non-initialized weights throw errors nevertheless
        model_class_copy._init_weights = self._mock_init_weights

        model = base_class(config)
        state_dict = model.state_dict()

        # this will often delete a single weight of a multi-weight module
        # to test an edge case
        random_key_to_del = random.choice(list(state_dict.keys()))
        del state_dict[random_key_to_del]

        # check that certain keys didn't get saved with the model
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            # overwrite the saved weights with the state dict missing one key
            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

            model_fast_init = model_class_copy.from_pretrained(tmpdirname)
            model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)

        # build each state dict once instead of once per key
        fast_state = model_fast_init.state_dict()
        slow_state = model_slow_init.state_dict()

        for key in fast_state.keys():
            if key == "layoutlmv2.visual_segment_embedding":
                # we skip the visual segment embedding as it has a custom initialization scheme
                continue
            # Compare with the largest absolute element-wise difference: a signed
            # sum lets positive and negative deviations cancel out, and a negative
            # total would always satisfy the assertion below.
            diff = (slow_state[key] - fast_state[key]).abs()
            max_diff = diff.max().item() if diff.numel() > 0 else 0.0
            self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True