🚨🚨🚨 Enforce single model initialization (#21431)
* Enforce single model initialization
* Add OneFormer example for problem 3
* Do it the Stas way
* Actually rename the uses...
* Rewrite test
* Try to change the test this way
* Fix all init slow/fast tests
* Break connection
* Fix more tests
* Fix test for initialization
* Remove custom test
* Quality
* Fix last failing tests
* The end?
This commit is contained in:
@@ -69,7 +69,6 @@ from transformers.testing_utils import (
|
||||
USER,
|
||||
CaptureLogger,
|
||||
TestCasePlus,
|
||||
is_flaky,
|
||||
is_pt_flax_cross_test,
|
||||
is_pt_tf_cross_test,
|
||||
is_staging_test,
|
||||
@@ -175,6 +174,9 @@ def _config_zero_init(config):
|
||||
for key in configs_no_init.__dict__.keys():
|
||||
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
|
||||
setattr(configs_no_init, key, 1e-10)
|
||||
if isinstance(getattr(configs_no_init, key, None), PretrainedConfig):
|
||||
no_init_subconfig = _config_zero_init(getattr(configs_no_init, key))
|
||||
setattr(configs_no_init, key, no_init_subconfig)
|
||||
return configs_no_init
|
||||
|
||||
|
||||
@@ -182,6 +184,31 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||
|
||||
|
||||
def _mock_init_weights(self, module):
    """Deterministically fill the parameters owned directly by ``module``.

    Every parameter is filled with a constant derived from the first letter
    of its name, ``ord(letter.lower()) - 110``, so differently named
    parameters (for example ``weight`` and ``bias``) receive different values.

    Args:
        self: Not used by the body. The parameter exists because the tests
            assign this function as the ``_init_weights`` attribute of a
            model class, so it is called as a method.
        module: Object exposing ``named_parameters(recurse=False)``. Only its
            own parameters are touched; parameters of child modules are left
            alone because ``recurse=False`` is passed.
    """
    for name, param in module.named_parameters(recurse=False):
        # Use the first letter of the name to get a value and go from a <> -13 to z <> 12
        value = ord(name[0].lower()) - 110
        param.data.fill_(value)
|
||||
|
||||
|
||||
def _mock_all_init_weights(self):
    """Re-run weight initialisation on every sub-module of ``self``.

    The tests assign this function as the ``init_weights`` attribute of a
    model class. Steps performed:

    1. If ``self.config.pruned_heads`` is truthy, call ``self.prune_heads``
       with it.
    2. If ``transformers.modeling_utils._init_weights`` is truthy, reset
       ``_is_hf_initialized`` to ``False`` on every module returned by
       ``self.modules()``, apply ``self._initialize_weights`` through
       ``self.apply`` and finally call ``self.tie_weights()``.

    Args:
        self: Object providing ``config``, ``prune_heads``, ``modules``,
            ``apply``, ``_initialize_weights`` and ``tie_weights``.
    """
    # Prune heads if needed
    if self.config.pruned_heads:
        self.prune_heads(self.config.pruned_heads)

    # Imported inside the function so the flag is read from the module at call time.
    import transformers.modeling_utils

    if transformers.modeling_utils._init_weights:
        # Clear the per-module flag before re-applying the initialiser.
        # NOTE(review): presumably `_initialize_weights` skips modules already
        # flagged as initialised -- confirm against transformers.modeling_utils.
        for module in self.modules():
            module._is_hf_initialized = False
        # Initialize weights
        self.apply(self._initialize_weights)

        # Tie weights should be skipped when not initializing all weights
        # since from_pretrained(...) calls tie weights anyways
        self.tie_weights()
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelTesterMixin:
|
||||
model_tester = None
|
||||
@@ -357,15 +384,10 @@ class ModelTesterMixin:
|
||||
model.gradient_checkpointing_disable()
|
||||
self.assertFalse(model.is_gradient_checkpointing)
|
||||
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
|
||||
@is_flaky()
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
return
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
@@ -387,7 +409,8 @@ class ModelTesterMixin:
|
||||
|
||||
# make init deterministic, but make sure that
|
||||
# non-initialized weights throw errors nevertheless
|
||||
model_class_copy._init_weights = self._mock_init_weights
|
||||
model_class_copy._init_weights = _mock_init_weights
|
||||
model_class_copy.init_weights = _mock_all_init_weights
|
||||
|
||||
model = base_class(config)
|
||||
state_dict = model.state_dict()
|
||||
@@ -404,13 +427,16 @@ class ModelTesterMixin:
|
||||
|
||||
model_fast_init = model_class_copy.from_pretrained(tmpdirname)
|
||||
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||
# Before we test anything
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
||||
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
return
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
@@ -432,7 +458,8 @@ class ModelTesterMixin:
|
||||
|
||||
# make init deterministic, but make sure that
|
||||
# non-initialized weights throw errors nevertheless
|
||||
base_class_copy._init_weights = self._mock_init_weights
|
||||
base_class_copy._init_weights = _mock_init_weights
|
||||
base_class_copy.init_weights = _mock_all_init_weights
|
||||
|
||||
model = model_class(config)
|
||||
state_dict = model.state_dict()
|
||||
@@ -454,7 +481,7 @@ class ModelTesterMixin:
|
||||
max_diff = torch.max(
|
||||
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
||||
).item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user