🚨🚨🚨 Enforce single model initialization (#21431)

* Enforce single model initialization

* Add OneFormer example for problem 3

* Do it the Stas way

* Actually rename the uses...

* Rewrite test

* Try to change the test this way

* Fix all init slow/fast tests

* Break connection

* Fix more tests

* Fix test for initialization

* Remove custom test

* Quality

* Fix last failing tests

* The end?
This commit is contained in:
Sylvain Gugger
2023-02-09 15:46:26 -05:00
committed by GitHub
parent 2020ac4bd6
commit 04b2f13c37
25 changed files with 277 additions and 123 deletions

View File

@@ -492,6 +492,48 @@ model = BrandNewBertModel(BrandNewBertConfig())
The above command will create a model according to the default parameters as defined in `BrandNewBertConfig()` with The above command will create a model according to the default parameters as defined in `BrandNewBertConfig()` with
random weights, thus making sure that the `init()` methods of all components works. random weights, thus making sure that the `init()` methods of all components works.
Note that all random initialization should happen in the `_init_weights` method of your `BrandnewBertPreTrainedModel`
class. It should initialize all leaf modules depending on the variables of the config. Here is an example with the
BERT `_init_weights` method:
```py
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
```
You can have some more custom schemes if you need a special initialization for some modules. For instance, in
`Wav2Vec2ForPreTraining`, the last two linear layers need to have the initialization of the regular PyTorch `nn.Linear`
but all the other ones should use an initialization as above. This is coded like this:
```py
def _init_weights(self, module):
"""Initialize the weights"""
if isinstnace(module, Wav2Vec2ForPreTraining):
module.project_hid.reset_parameters()
module.project_q.reset_parameters()
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
```
The `_is_hf_initialized` flag is internally used to make sure we only initialize a submodule once. By setting it to
`True` for `module.project_q` and `module.project_hid`, we make sure the custom initialization we did is not overridden later on,
the `_init_weights` function won't be applied to them.
**6. Write a conversion script** **6. Write a conversion script**
Next, you should write a conversion script that lets you convert the checkpoint you used to debug *brand_new_bert* in Next, you should write a conversion script that lets you convert the checkpoint you used to debug *brand_new_bert* in

View File

@@ -436,6 +436,17 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
) )
def set_initialized_submodules(model, state_dict_keys):
"""
Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
dict.
"""
for module_name, module in model.named_modules():
loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
module._is_hf_initialized = True
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# Convert old format to new format if needed from a PyTorch state_dict # Convert old format to new format if needed from a PyTorch state_dict
old_keys = [] old_keys = []
@@ -1176,7 +1187,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
""" """
Initialize the weights. This method should be overridden by derived class. Initialize the weights. This method should be overridden by derived class.
""" """
raise NotImplementedError(f"Make sure `_init_weights` is implemented for {self.__class__}") pass
def _initialize_weights(self, module):
"""
Initialize the weights if they are not already initialized.
"""
if getattr(module, "_is_hf_initialized", False):
return
self._init_weights(module)
module._is_hf_initialized = True
def tie_weights(self): def tie_weights(self):
""" """
@@ -1505,7 +1525,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
def init_weights(self): def init_weights(self):
""" """
If needed prunes and maybe initializes weights. If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
initialization logic in `_init_weights`.
""" """
# Prune heads if needed # Prune heads if needed
if self.config.pruned_heads: if self.config.pruned_heads:
@@ -1513,7 +1534,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if _init_weights: if _init_weights:
# Initialize weights # Initialize weights
self.apply(self._init_weights) self.apply(self._initialize_weights)
# Tie weights should be skipped when not initializing all weights # Tie weights should be skipped when not initializing all weights
# since from_pretrained(...) calls tie weights anyways # since from_pretrained(...) calls tie weights anyways
@@ -2713,11 +2734,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init: if _fast_init:
uninitialized_modules = model.retrieve_modules_from_names( if remove_prefix_from_model:
missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
) elif add_prefix_to_model:
for module in uninitialized_modules: _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
model._init_weights(module) else:
_loaded_keys = loaded_keys
set_initialized_submodules(model, _loaded_keys)
# This will only initialize submodules that are not marked as initialized by the line above.
model.apply(model._initialize_weights)
# Set some modules to fp32 if any # Set some modules to fp32 if any
if keep_in_fp32_modules is not None: if keep_in_fp32_modules is not None:

View File

@@ -1067,10 +1067,12 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
module.text_projection.weight, module.text_projection.weight,
std=module.text_embed_dim**-0.5 * self.config.initializer_factor, std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
) )
module.text_projection._is_hf_initialized = True
nn.init.normal_( nn.init.normal_(
module.visual_projection.weight, module.visual_projection.weight,
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
) )
module.visual_projection._is_hf_initialized = True
elif isinstance(module, nn.LayerNorm): elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)

View File

@@ -1473,8 +1473,9 @@ class BartForSequenceClassification(BartPretrainedModel):
config.num_labels, config.num_labels,
config.classifier_dropout, config.classifier_dropout,
) )
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj) # Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
@@ -1601,7 +1602,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
self.model = BartModel(config) self.model = BartModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.model._init_weights(self.qa_outputs) # Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(

View File

@@ -2658,8 +2658,9 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
config.num_labels, config.num_labels,
config.classifier_dropout, config.classifier_dropout,
) )
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj) # Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
@@ -2785,7 +2786,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
self.model = BigBirdPegasusModel(config) self.model = BigBirdPegasusModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.model._init_weights(self.qa_outputs) # Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(

View File

@@ -1186,6 +1186,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
base_model = FSMTModel(config) base_model = FSMTModel(config)
self.model = base_model self.model = base_model
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(FSMT_GENERATION_EXAMPLE) @add_end_docstrings(FSMT_GENERATION_EXAMPLE)

View File

@@ -2543,8 +2543,9 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
config.num_labels, config.num_labels,
config.classifier_dropout, config.classifier_dropout,
) )
self.led._init_weights(self.classification_head.dense)
self.led._init_weights(self.classification_head.out_proj) # Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
@@ -2672,7 +2673,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
self.led = LEDModel(config) self.led = LEDModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.led._init_weights(self.qa_outputs) # Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(

View File

@@ -866,6 +866,9 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels]) self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])
# Initialize weights and apply final processing
self.post_init()
@property @property
def channels(self): def channels(self):
return [self.out_feature_channels[name] for name in self.out_features] return [self.out_feature_channels[name] for name in self.out_features]

View File

@@ -1447,8 +1447,9 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
config.num_labels, config.num_labels,
config.classifier_dropout, config.classifier_dropout,
) )
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj) # Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
@@ -1574,7 +1575,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
self.model = MBartModel(config) self.model = MBartModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.model._init_weights(self.qa_outputs) # Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(

View File

@@ -1610,8 +1610,8 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
config.classifier_dropout, config.classifier_dropout,
) )
self.model._init_weights(self.classification_head.dense) # Initialize weights and apply final processing
self.model._init_weights(self.classification_head.out_proj) self.post_init()
def set_lightweight_tuning(self): def set_lightweight_tuning(self):
self.model.set_lightweight_tuning() self.model.set_lightweight_tuning()
@@ -1737,7 +1737,8 @@ class MvpForQuestionAnswering(MvpPreTrainedModel):
self.model = MvpModel(config) self.model = MvpModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.model._init_weights(self.qa_outputs) # Initialize weights and apply final processing
self.post_init()
def set_lightweight_tuning(self): def set_lightweight_tuning(self):
self.model.set_lightweight_tuning() self.model.set_lightweight_tuning()

View File

@@ -2801,6 +2801,7 @@ class OneFormerPreTrainedModel(PreTrainedModel):
elif isinstance(module, OneFormerTransformerDecoder): elif isinstance(module, OneFormerTransformerDecoder):
nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std) nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
nn.init.constant_(module.query_input_projection.bias, 0) nn.init.constant_(module.query_input_projection.bias, 0)
module.query_input_projection._is_hf_initialized = True
elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention): elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0) nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads) thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads)

View File

@@ -1420,8 +1420,9 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
config.num_labels, config.num_labels,
config.classifier_dropout, config.classifier_dropout,
) )
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj) # Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(

View File

@@ -301,6 +301,12 @@ class UperNetPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
def _init_weights(self, module):
if isinstance(module, UperNetPreTrainedModel):
module.backbone.init_weights()
module.decode_head.init_weights()
module.auxiliary_head.init_weights()
def init_weights(self): def init_weights(self):
"""Initialize the weights""" """Initialize the weights"""
self.backbone.init_weights() self.backbone.init_weights()

View File

@@ -1049,8 +1049,14 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
if isinstance(module, Wav2Vec2ForPreTraining):
module.project_hid.reset_parameters()
module.project_q.reset_parameters()
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
# gumbel softmax requires special init # gumbel softmax requires special init
if isinstance(module, Wav2Vec2GumbelVectorQuantizer): elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1) module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_() module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors) nn.init.uniform_(module.codevectors)
@@ -1345,13 +1351,12 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
self.quantizer = Wav2Vec2GumbelVectorQuantizer(config) self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
# Initialize weights and apply final processing
self.post_init()
# make sure that project_hid & project_q are initialized like normal linear layers
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
# Initialize weights and apply final processing
self.post_init()
def set_gumbel_temperature(self, temperature: int): def set_gumbel_temperature(self, temperature: int):
""" """
Set the Gumbel softmax temperature to a given value. Only necessary for training Set the Gumbel softmax temperature to a given value. Only necessary for training

View File

@@ -1089,8 +1089,14 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
if isinstance(module, Wav2Vec2ConformerForPreTraining):
module.project_hid.reset_parameters()
module.project_q.reset_parameters()
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
# gumbel softmax requires special init # gumbel softmax requires special init
if isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1) module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_() module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors) nn.init.uniform_(module.codevectors)
@@ -1381,13 +1387,12 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config) self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
# Initialize weights and apply final processing
self.post_init()
# make sure that project_hid & project_q are initialized like normal linear layers
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
def set_gumbel_temperature(self, temperature: int): def set_gumbel_temperature(self, temperature: int):
""" """

View File

@@ -962,7 +962,6 @@ class WavLMAdapterLayer(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->WavLM, wav2vec2->wavlm
class WavLMPreTrainedModel(PreTrainedModel): class WavLMPreTrainedModel(PreTrainedModel):
""" """
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained

View File

@@ -1496,3 +1496,6 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
def test_retain_grad_hidden_states_attentions(self): def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients # decoder cannot keep gradients
return return
def test_save_load_fast_init_from_base(self):
pass

View File

@@ -410,17 +410,23 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
configs_no_init = _config_zero_init(config) configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config=configs_no_init) model = model_class(config=configs_no_init)
# Skip the check for the backbone
for name, module in model.named_modules():
if module.__class__.__name__ == "DetaBackboneWithPositionalEncodings":
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
break
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.requires_grad: if param.requires_grad:
if param.requires_grad: if (
if ( "level_embed" in name
"level_embed" in name or "sampling_offsets.bias" in name
or "sampling_offsets.bias" in name or "value_proj" in name
or "value_proj" in name or "output_proj" in name
or "output_proj" in name or "reference_points" in name
or "reference_points" in name or name in backbone_params
): ):
continue continue
self.assertIn( self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(), ((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0], [0.0, 1.0],

View File

@@ -24,7 +24,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available(): if is_torch_available():
@@ -242,6 +242,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss loss = model(**inputs).loss
loss.backward() loss.backward()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
backbone_params = []
for name, module in model.named_modules():
if module.__class__.__name__ == "DPTViTHybridEmbeddings":
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
break
for name, param in model.named_parameters():
if param.requires_grad:
if name in backbone_params:
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:

View File

@@ -24,7 +24,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available(): if is_torch_available():
@@ -256,6 +256,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss loss = model(**inputs).loss
loss.backward() loss.backward()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
backbone_params = []
for name, module in model.named_modules():
if module.__class__.__name__ == "DPTViTHybridEmbeddings":
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
break
for name, param in model.named_parameters():
if param.requires_grad:
if name in backbone_params:
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[1:]: for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[1:]:

View File

@@ -15,9 +15,6 @@
""" Testing suite for the PyTorch LayoutLMv2 model. """ """ Testing suite for the PyTorch LayoutLMv2 model. """
import os
import random
import tempfile
import unittest import unittest
from transformers.testing_utils import require_detectron2, require_torch, require_torch_multi_gpu, slow, torch_device from transformers.testing_utils import require_detectron2, require_torch, require_torch_multi_gpu, slow, torch_device
@@ -31,7 +28,6 @@ if is_torch_available():
import torch import torch
from transformers import ( from transformers import (
MODEL_MAPPING,
LayoutLMv2Config, LayoutLMv2Config,
LayoutLMv2ForQuestionAnswering, LayoutLMv2ForQuestionAnswering,
LayoutLMv2ForSequenceClassification, LayoutLMv2ForSequenceClassification,
@@ -312,54 +308,6 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs) self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_save_load_fast_init_from_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
base_class = MODEL_MAPPING[config.__class__]
if isinstance(base_class, tuple):
base_class = base_class[0]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(model_class):
pass
model_class_copy = CopyClass
# make sure that all keys are expected for test
model_class_copy._keys_to_ignore_on_load_missing = []
# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
model_class_copy._init_weights = self._mock_init_weights
model = base_class(config)
state_dict = model.state_dict()
# this will often delete a single weight of a multi-weight module
# to test an edge case
random_key_to_del = random.choice(list(state_dict.keys()))
del state_dict[random_key_to_del]
# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
model_fast_init = model_class_copy.from_pretrained(tmpdirname)
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
for key in model_fast_init.state_dict().keys():
if key == "layoutlmv2.visual_segment_embedding":
# we skip the visual segment embedding as it has a custom initialization scheme
continue
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_attention_outputs(self): def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True config.return_dict = True

View File

@@ -436,10 +436,10 @@ class ProphetNetModelTester:
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
labels=lm_labels, labels=lm_labels,
) )
self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(4.5819, device=torch_device), atol=1e-3)) self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(4.5981, device=torch_device), atol=1e-3))
expected_logit_slice = torch.tensor( expected_logit_slice = torch.tensor(
[-0.1565, 0.0418, 0.1207, 0.0030, 0.0665, 0.0467, 0.0412], device=torch_device [-0.0648, 0.0790, 0.0360, 0.0089, 0.0039, -0.0639, 0.0131], device=torch_device
) )
self.parent.assertTrue(torch.allclose(result.logits[0, :, 1], expected_logit_slice, atol=1e-3)) self.parent.assertTrue(torch.allclose(result.logits[0, :, 1], expected_logit_slice, atol=1e-3))

View File

@@ -1145,10 +1145,11 @@ class ReformerIntegrationTests(unittest.TestCase):
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
output_slice = hidden_states[1, -1, :5] output_slice = hidden_states[1, -1, :5]
expected_output_slice = torch.tensor( expected_output_slice = torch.tensor(
[0.0256, -0.0121, 0.0636, 0.0024, -0.0393], [0.1018, -0.2026, 0.2116, 0.0270, -0.1233],
dtype=torch.float, dtype=torch.float,
device=torch_device, device=torch_device,
) )
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_local_lm_model_grad(self): def test_local_lm_model_grad(self):
@@ -1163,25 +1164,25 @@ class ReformerIntegrationTests(unittest.TestCase):
input_ids, _ = self._get_input_ids_and_mask() input_ids, _ = self._get_input_ids_and_mask()
loss = model(input_ids=input_ids, labels=input_ids)[0] loss = model(input_ids=input_ids, labels=input_ids)[0]
self.assertTrue(torch.allclose(loss, torch.tensor(5.7786, dtype=torch.float, device=torch_device), atol=1e-3)) self.assertTrue(torch.allclose(loss, torch.tensor(5.8019, dtype=torch.float, device=torch_device), atol=1e-3))
loss.backward() loss.backward()
# check last grads to cover all proable errors # check last grads to cover all proable errors
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
expected_grad_slice_word = torch.tensor( expected_grad_slice_word = torch.tensor(
[-0.0005, 0.0001, 0.0002, 0.0003, 0.0006], [-0.0005, -0.0001, -0.0002, -0.0006, -0.0006],
dtype=torch.float, dtype=torch.float,
device=torch_device, device=torch_device,
) )
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
expected_grad_slice_pos_fac_1 = torch.tensor( expected_grad_slice_pos_fac_1 = torch.tensor(
[0.0037, -1.3793, -1.0231, -1.5230, -2.5306], [-0.5235, 0.5704, 0.0922, -0.3140, 0.9928],
dtype=torch.float, dtype=torch.float,
device=torch_device, device=torch_device,
) )
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
expected_grad_slice_pos_fac_2 = torch.tensor( expected_grad_slice_pos_fac_2 = torch.tensor(
[-1.3165, 0.5168, 0.7785, 1.0811, -0.9830], [1.7960, 1.7668, 0.5593, 0.0907, 1.8342],
dtype=torch.float, dtype=torch.float,
device=torch_device, device=torch_device,
) )
@@ -1203,24 +1204,24 @@ class ReformerIntegrationTests(unittest.TestCase):
input_ids, _ = self._get_input_ids_and_mask() input_ids, _ = self._get_input_ids_and_mask()
loss = model(input_ids=input_ids, labels=input_ids)[0] loss = model(input_ids=input_ids, labels=input_ids)[0]
self.assertTrue(torch.allclose(loss, torch.tensor(5.7819, dtype=torch.float, device=torch_device), atol=1e-3)) self.assertTrue(torch.allclose(loss, torch.tensor(5.7854, dtype=torch.float, device=torch_device), atol=1e-3))
loss.backward() loss.backward()
# check last grads to cover all proable errors # check last grads to cover all proable errors
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
expected_grad_slice_word = torch.tensor( expected_grad_slice_word = torch.tensor(
[2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04], [0.0004, 0.0003, 0.0006, -0.0004, 0.0002],
dtype=torch.float, dtype=torch.float,
device=torch_device, device=torch_device,
) )
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
expected_grad_slice_pos_fac_1 = torch.tensor( expected_grad_slice_pos_fac_1 = torch.tensor(
[-0.0984, 0.6283, 0.4282, 1.2960, 0.6897], [-0.3792, 0.5593, -1.6993, 0.2033, 0.4131],
dtype=torch.float, dtype=torch.float,
device=torch_device, device=torch_device,
) )
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
expected_grad_slice_pos_fac_2 = torch.tensor( expected_grad_slice_pos_fac_2 = torch.tensor(
[0.4626, -0.0231, -0.0172, 0.1081, 0.3805], [-1.4212, -0.3201, -1.1944, 0.1258, 0.2856],
dtype=torch.float, dtype=torch.float,
device=torch_device, device=torch_device,
) )

View File

@@ -23,7 +23,7 @@ from transformers.testing_utils import require_accelerate, require_torch, requir
from transformers.utils import cached_property, is_torch_available, is_vision_available from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available(): if is_torch_available():
@@ -198,6 +198,28 @@ class ViTHybridModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs) self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
for name, module in model.named_modules():
if module.__class__.__name__ == "ViTHybridPatchEmbeddings":
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
break
for name, param in model.named_parameters():
if param.requires_grad:
if name in backbone_params:
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:

View File

@@ -69,7 +69,6 @@ from transformers.testing_utils import (
USER, USER,
CaptureLogger, CaptureLogger,
TestCasePlus, TestCasePlus,
is_flaky,
is_pt_flax_cross_test, is_pt_flax_cross_test,
is_pt_tf_cross_test, is_pt_tf_cross_test,
is_staging_test, is_staging_test,
@@ -175,6 +174,9 @@ def _config_zero_init(config):
for key in configs_no_init.__dict__.keys(): for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key: if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
setattr(configs_no_init, key, 1e-10) setattr(configs_no_init, key, 1e-10)
if isinstance(getattr(configs_no_init, key, None), PretrainedConfig):
no_init_subconfig = _config_zero_init(getattr(configs_no_init, key))
setattr(configs_no_init, key, no_init_subconfig)
return configs_no_init return configs_no_init
@@ -182,6 +184,31 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification" TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
def _mock_init_weights(self, module):
for name, param in module.named_parameters(recurse=False):
# Use the first letter of the name to get a value and go from a <> -13 to z <> 12
value = ord(name[0].lower()) - 110
param.data.fill_(value)
def _mock_all_init_weights(self):
# Prune heads if needed
if self.config.pruned_heads:
self.prune_heads(self.config.pruned_heads)
import transformers.modeling_utils
if transformers.modeling_utils._init_weights:
for module in self.modules():
module._is_hf_initialized = False
# Initialize weights
self.apply(self._initialize_weights)
# Tie weights should be skipped when not initializing all weights
# since from_pretrained(...) calls tie weights anyways
self.tie_weights()
@require_torch @require_torch
class ModelTesterMixin: class ModelTesterMixin:
model_tester = None model_tester = None
@@ -357,15 +384,10 @@ class ModelTesterMixin:
model.gradient_checkpointing_disable() model.gradient_checkpointing_disable()
self.assertFalse(model.is_gradient_checkpointing) self.assertFalse(model.is_gradient_checkpointing)
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
@is_flaky()
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
return
base_class = MODEL_MAPPING[config.__class__] base_class = MODEL_MAPPING[config.__class__]
if isinstance(base_class, tuple): if isinstance(base_class, tuple):
@@ -387,7 +409,8 @@ class ModelTesterMixin:
# make init deterministic, but make sure that # make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless # non-initialized weights throw errors nevertheless
model_class_copy._init_weights = self._mock_init_weights model_class_copy._init_weights = _mock_init_weights
model_class_copy.init_weights = _mock_all_init_weights
model = base_class(config) model = base_class(config)
state_dict = model.state_dict() state_dict = model.state_dict()
@@ -404,13 +427,16 @@ class ModelTesterMixin:
model_fast_init = model_class_copy.from_pretrained(tmpdirname) model_fast_init = model_class_copy.from_pretrained(tmpdirname)
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False) model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
# Before we test anything
for key in model_fast_init.state_dict().keys(): for key in model_fast_init.state_dict().keys():
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
def test_save_load_fast_init_to_base(self): def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
return
base_class = MODEL_MAPPING[config.__class__] base_class = MODEL_MAPPING[config.__class__]
if isinstance(base_class, tuple): if isinstance(base_class, tuple):
@@ -432,7 +458,8 @@ class ModelTesterMixin:
# make init deterministic, but make sure that # make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless # non-initialized weights throw errors nevertheless
base_class_copy._init_weights = self._mock_init_weights base_class_copy._init_weights = _mock_init_weights
base_class_copy.init_weights = _mock_all_init_weights
model = model_class(config) model = model_class(config)
state_dict = model.state_dict() state_dict = model.state_dict()
@@ -454,7 +481,7 @@ class ModelTesterMixin:
max_diff = torch.max( max_diff = torch.max(
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]) torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
).item() ).item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
def test_initialization(self): def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()