🚨🚨🚨 Enforce single model initialization (#21431)
* Enforce single model initialization * Add OneFormer example for problem 3 * Do it the Stas way * Actually rename the uses... * Rewrite test * Try to change the test this way * Fix all init slow/fast tests * Break connection * Fix more tests * Fix test for initialization * Remove custom test * Quality * Fix last failing tests * The end?
This commit is contained in:
@@ -492,6 +492,48 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
|||||||
The above command will create a model according to the default parameters as defined in `BrandNewBertConfig()` with
|
The above command will create a model according to the default parameters as defined in `BrandNewBertConfig()` with
|
||||||
random weights, thus making sure that the `init()` methods of all components work.
|
random weights, thus making sure that the `init()` methods of all components work.
|
||||||
|
|
||||||
|
Note that all random initialization should happen in the `_init_weights` method of your `BrandnewBertPreTrainedModel`
|
||||||
|
class. It should initialize all leaf modules depending on the variables of the config. Here is an example with the
|
||||||
|
BERT `_init_weights` method:
|
||||||
|
|
||||||
|
```py
|
||||||
|
def _init_weights(self, module):
|
||||||
|
"""Initialize the weights"""
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can have some more custom schemes if you need a special initialization for some modules. For instance, in
|
||||||
|
`Wav2Vec2ForPreTraining`, the last two linear layers need to have the initialization of the regular PyTorch `nn.Linear`
|
||||||
|
but all the other ones should use an initialization as above. This is coded like this:
|
||||||
|
|
||||||
|
```py
|
||||||
|
def _init_weights(self, module):
|
||||||
|
"""Initialize the weights"""
|
||||||
|
if isinstance(module, Wav2Vec2ForPreTraining):
|
||||||
|
module.project_hid.reset_parameters()
|
||||||
|
module.project_q.reset_parameters()
|
||||||
|
module.project_hid._is_hf_initialized = True
|
||||||
|
module.project_q._is_hf_initialized = True
|
||||||
|
elif isinstance(module, nn.Linear):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
```
|
||||||
|
|
||||||
|
The `_is_hf_initialized` flag is internally used to make sure we only initialize a submodule once. By setting it to
|
||||||
|
`True` for `module.project_q` and `module.project_hid`, we make sure the custom initialization we did is not overridden later on,
|
||||||
|
as the `_init_weights` function won't be applied to them.
|
||||||
|
|
||||||
**6. Write a conversion script**
|
**6. Write a conversion script**
|
||||||
|
|
||||||
Next, you should write a conversion script that lets you convert the checkpoint you used to debug *brand_new_bert* in
|
Next, you should write a conversion script that lets you convert the checkpoint you used to debug *brand_new_bert* in
|
||||||
|
|||||||
@@ -436,6 +436,17 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_initialized_submodules(model, state_dict_keys):
|
||||||
|
"""
|
||||||
|
Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
|
||||||
|
dict.
|
||||||
|
"""
|
||||||
|
for module_name, module in model.named_modules():
|
||||||
|
loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
|
||||||
|
if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
|
||||||
|
module._is_hf_initialized = True
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
@@ -1176,7 +1187,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"""
|
"""
|
||||||
Initialize the weights. This method should be overridden by derived class.
|
Initialize the weights. This method should be overridden by derived class.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Make sure `_init_weights` is implemented for {self.__class__}")
|
pass
|
||||||
|
|
||||||
|
def _initialize_weights(self, module):
|
||||||
|
"""
|
||||||
|
Initialize the weights if they are not already initialized.
|
||||||
|
"""
|
||||||
|
if getattr(module, "_is_hf_initialized", False):
|
||||||
|
return
|
||||||
|
self._init_weights(module)
|
||||||
|
module._is_hf_initialized = True
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
"""
|
"""
|
||||||
@@ -1505,7 +1525,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
"""
|
"""
|
||||||
If needed prunes and maybe initializes weights.
|
If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
|
||||||
|
initialization logic in `_init_weights`.
|
||||||
"""
|
"""
|
||||||
# Prune heads if needed
|
# Prune heads if needed
|
||||||
if self.config.pruned_heads:
|
if self.config.pruned_heads:
|
||||||
@@ -1513,7 +1534,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
if _init_weights:
|
if _init_weights:
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
self.apply(self._init_weights)
|
self.apply(self._initialize_weights)
|
||||||
|
|
||||||
# Tie weights should be skipped when not initializing all weights
|
# Tie weights should be skipped when not initializing all weights
|
||||||
# since from_pretrained(...) calls tie weights anyways
|
# since from_pretrained(...) calls tie weights anyways
|
||||||
@@ -2713,11 +2734,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
|
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
|
||||||
if _fast_init:
|
if _fast_init:
|
||||||
uninitialized_modules = model.retrieve_modules_from_names(
|
if remove_prefix_from_model:
|
||||||
missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model
|
_loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
|
||||||
)
|
elif add_prefix_to_model:
|
||||||
for module in uninitialized_modules:
|
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
|
||||||
model._init_weights(module)
|
else:
|
||||||
|
_loaded_keys = loaded_keys
|
||||||
|
set_initialized_submodules(model, _loaded_keys)
|
||||||
|
# This will only initialize submodules that are not marked as initialized by the line above.
|
||||||
|
model.apply(model._initialize_weights)
|
||||||
|
|
||||||
# Set some modules to fp32 if any
|
# Set some modules to fp32 if any
|
||||||
if keep_in_fp32_modules is not None:
|
if keep_in_fp32_modules is not None:
|
||||||
|
|||||||
@@ -1067,10 +1067,12 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
|
|||||||
module.text_projection.weight,
|
module.text_projection.weight,
|
||||||
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
||||||
)
|
)
|
||||||
|
module.text_projection._is_hf_initialized = True
|
||||||
nn.init.normal_(
|
nn.init.normal_(
|
||||||
module.visual_projection.weight,
|
module.visual_projection.weight,
|
||||||
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
||||||
)
|
)
|
||||||
|
module.visual_projection._is_hf_initialized = True
|
||||||
elif isinstance(module, nn.LayerNorm):
|
elif isinstance(module, nn.LayerNorm):
|
||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
module.weight.data.fill_(1.0)
|
module.weight.data.fill_(1.0)
|
||||||
|
|||||||
@@ -1473,8 +1473,9 @@ class BartForSequenceClassification(BartPretrainedModel):
|
|||||||
config.num_labels,
|
config.num_labels,
|
||||||
config.classifier_dropout,
|
config.classifier_dropout,
|
||||||
)
|
)
|
||||||
self.model._init_weights(self.classification_head.dense)
|
|
||||||
self.model._init_weights(self.classification_head.out_proj)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
@@ -1601,7 +1602,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
|
|||||||
self.model = BartModel(config)
|
self.model = BartModel(config)
|
||||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
||||||
self.model._init_weights(self.qa_outputs)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
|
|||||||
@@ -2658,8 +2658,9 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
|||||||
config.num_labels,
|
config.num_labels,
|
||||||
config.classifier_dropout,
|
config.classifier_dropout,
|
||||||
)
|
)
|
||||||
self.model._init_weights(self.classification_head.dense)
|
|
||||||
self.model._init_weights(self.classification_head.out_proj)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
@@ -2785,7 +2786,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
|
|||||||
self.model = BigBirdPegasusModel(config)
|
self.model = BigBirdPegasusModel(config)
|
||||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
||||||
self.model._init_weights(self.qa_outputs)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
|
|||||||
@@ -1186,6 +1186,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
|||||||
base_model = FSMTModel(config)
|
base_model = FSMTModel(config)
|
||||||
self.model = base_model
|
self.model = base_model
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
@add_end_docstrings(FSMT_GENERATION_EXAMPLE)
|
@add_end_docstrings(FSMT_GENERATION_EXAMPLE)
|
||||||
|
|||||||
@@ -2543,8 +2543,9 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
|||||||
config.num_labels,
|
config.num_labels,
|
||||||
config.classifier_dropout,
|
config.classifier_dropout,
|
||||||
)
|
)
|
||||||
self.led._init_weights(self.classification_head.dense)
|
|
||||||
self.led._init_weights(self.classification_head.out_proj)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
@@ -2672,7 +2673,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
|
|||||||
self.led = LEDModel(config)
|
self.led = LEDModel(config)
|
||||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
||||||
self.led._init_weights(self.qa_outputs)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
|
|||||||
@@ -866,6 +866,9 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
|
|||||||
|
|
||||||
self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])
|
self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def channels(self):
|
def channels(self):
|
||||||
return [self.out_feature_channels[name] for name in self.out_features]
|
return [self.out_feature_channels[name] for name in self.out_features]
|
||||||
|
|||||||
@@ -1447,8 +1447,9 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
|||||||
config.num_labels,
|
config.num_labels,
|
||||||
config.classifier_dropout,
|
config.classifier_dropout,
|
||||||
)
|
)
|
||||||
self.model._init_weights(self.classification_head.dense)
|
|
||||||
self.model._init_weights(self.classification_head.out_proj)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
@@ -1574,7 +1575,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
|
|||||||
self.model = MBartModel(config)
|
self.model = MBartModel(config)
|
||||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
||||||
self.model._init_weights(self.qa_outputs)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
|
|||||||
@@ -1610,8 +1610,8 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
|
|||||||
config.classifier_dropout,
|
config.classifier_dropout,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model._init_weights(self.classification_head.dense)
|
# Initialize weights and apply final processing
|
||||||
self.model._init_weights(self.classification_head.out_proj)
|
self.post_init()
|
||||||
|
|
||||||
def set_lightweight_tuning(self):
|
def set_lightweight_tuning(self):
|
||||||
self.model.set_lightweight_tuning()
|
self.model.set_lightweight_tuning()
|
||||||
@@ -1737,7 +1737,8 @@ class MvpForQuestionAnswering(MvpPreTrainedModel):
|
|||||||
self.model = MvpModel(config)
|
self.model = MvpModel(config)
|
||||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
||||||
self.model._init_weights(self.qa_outputs)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
def set_lightweight_tuning(self):
|
def set_lightweight_tuning(self):
|
||||||
self.model.set_lightweight_tuning()
|
self.model.set_lightweight_tuning()
|
||||||
|
|||||||
@@ -2801,6 +2801,7 @@ class OneFormerPreTrainedModel(PreTrainedModel):
|
|||||||
elif isinstance(module, OneFormerTransformerDecoder):
|
elif isinstance(module, OneFormerTransformerDecoder):
|
||||||
nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
|
nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
|
||||||
nn.init.constant_(module.query_input_projection.bias, 0)
|
nn.init.constant_(module.query_input_projection.bias, 0)
|
||||||
|
module.query_input_projection._is_hf_initialized = True
|
||||||
elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention):
|
elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention):
|
||||||
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
|
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
|
||||||
thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads)
|
thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads)
|
||||||
|
|||||||
@@ -1420,8 +1420,9 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
|
|||||||
config.num_labels,
|
config.num_labels,
|
||||||
config.classifier_dropout,
|
config.classifier_dropout,
|
||||||
)
|
)
|
||||||
self.model._init_weights(self.classification_head.dense)
|
|
||||||
self.model._init_weights(self.classification_head.out_proj)
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
|
|||||||
@@ -301,6 +301,12 @@ class UperNetPreTrainedModel(PreTrainedModel):
|
|||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if isinstance(module, UperNetPreTrainedModel):
|
||||||
|
module.backbone.init_weights()
|
||||||
|
module.decode_head.init_weights()
|
||||||
|
module.auxiliary_head.init_weights()
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
self.backbone.init_weights()
|
self.backbone.init_weights()
|
||||||
|
|||||||
@@ -1049,8 +1049,14 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
|
||||||
|
if isinstance(module, Wav2Vec2ForPreTraining):
|
||||||
|
module.project_hid.reset_parameters()
|
||||||
|
module.project_q.reset_parameters()
|
||||||
|
module.project_hid._is_hf_initialized = True
|
||||||
|
module.project_q._is_hf_initialized = True
|
||||||
# gumbel softmax requires special init
|
# gumbel softmax requires special init
|
||||||
if isinstance(module, Wav2Vec2GumbelVectorQuantizer):
|
elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
|
||||||
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
|
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
|
||||||
module.weight_proj.bias.data.zero_()
|
module.weight_proj.bias.data.zero_()
|
||||||
nn.init.uniform_(module.codevectors)
|
nn.init.uniform_(module.codevectors)
|
||||||
@@ -1345,13 +1351,12 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
|||||||
|
|
||||||
self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
|
self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
# make sure that project_hid & project_q are initialized like normal linear layers
|
|
||||||
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
|
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
|
||||||
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
|
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
def set_gumbel_temperature(self, temperature: int):
|
def set_gumbel_temperature(self, temperature: int):
|
||||||
"""
|
"""
|
||||||
Set the Gumbel softmax temperature to a given value. Only necessary for training
|
Set the Gumbel softmax temperature to a given value. Only necessary for training
|
||||||
|
|||||||
@@ -1089,8 +1089,14 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
# Wav2Vec2ConformerForPreTraining last 2 linear layers need standard Linear init.
|
||||||
|
if isinstance(module, Wav2Vec2ConformerForPreTraining):
|
||||||
|
module.project_hid.reset_parameters()
|
||||||
|
module.project_q.reset_parameters()
|
||||||
|
module.project_hid._is_hf_initialized = True
|
||||||
|
module.project_q._is_hf_initialized = True
|
||||||
# gumbel softmax requires special init
|
# gumbel softmax requires special init
|
||||||
if isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
|
elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
|
||||||
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
|
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
|
||||||
module.weight_proj.bias.data.zero_()
|
module.weight_proj.bias.data.zero_()
|
||||||
nn.init.uniform_(module.codevectors)
|
nn.init.uniform_(module.codevectors)
|
||||||
@@ -1381,13 +1387,12 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
|||||||
|
|
||||||
self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
|
self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
# make sure that project_hid & project_q are initialized like normal linear layers
|
|
||||||
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
|
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
|
||||||
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
|
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
|
||||||
def set_gumbel_temperature(self, temperature: int):
|
def set_gumbel_temperature(self, temperature: int):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -962,7 +962,6 @@ class WavLMAdapterLayer(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->WavLM, wav2vec2->wavlm
|
|
||||||
class WavLMPreTrainedModel(PreTrainedModel):
|
class WavLMPreTrainedModel(PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
|
|||||||
@@ -1496,3 +1496,6 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
|
|||||||
def test_retain_grad_hidden_states_attentions(self):
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
# decoder cannot keep gradients
|
# decoder cannot keep gradients
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def test_save_load_fast_init_from_base(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -410,17 +410,23 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
configs_no_init = _config_zero_init(config)
|
configs_no_init = _config_zero_init(config)
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config=configs_no_init)
|
model = model_class(config=configs_no_init)
|
||||||
|
# Skip the check for the backbone
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if module.__class__.__name__ == "DetaBackboneWithPositionalEncodings":
|
||||||
|
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
|
||||||
|
break
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if param.requires_grad:
|
if (
|
||||||
if (
|
"level_embed" in name
|
||||||
"level_embed" in name
|
or "sampling_offsets.bias" in name
|
||||||
or "sampling_offsets.bias" in name
|
or "value_proj" in name
|
||||||
or "value_proj" in name
|
or "output_proj" in name
|
||||||
or "output_proj" in name
|
or "reference_points" in name
|
||||||
or "reference_points" in name
|
or name in backbone_params
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
[0.0, 1.0],
|
[0.0, 1.0],
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from transformers.models.auto import get_values
|
|||||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -242,6 +242,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
loss = model(**inputs).loss
|
loss = model(**inputs).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
# Skip the check for the backbone
|
||||||
|
backbone_params = []
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if module.__class__.__name__ == "DPTViTHybridEmbeddings":
|
||||||
|
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
|
||||||
|
break
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if name in backbone_params:
|
||||||
|
continue
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from transformers.models.auto import get_values
|
|||||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -256,6 +256,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
loss = model(**inputs).loss
|
loss = model(**inputs).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
# Skip the check for the backbone
|
||||||
|
backbone_params = []
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if module.__class__.__name__ == "DPTViTHybridEmbeddings":
|
||||||
|
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
|
||||||
|
break
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if name in backbone_params:
|
||||||
|
continue
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[1:]:
|
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[1:]:
|
||||||
|
|||||||
@@ -15,9 +15,6 @@
|
|||||||
""" Testing suite for the PyTorch LayoutLMv2 model. """
|
""" Testing suite for the PyTorch LayoutLMv2 model. """
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import tempfile
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.testing_utils import require_detectron2, require_torch, require_torch_multi_gpu, slow, torch_device
|
from transformers.testing_utils import require_detectron2, require_torch, require_torch_multi_gpu, slow, torch_device
|
||||||
@@ -31,7 +28,6 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
MODEL_MAPPING,
|
|
||||||
LayoutLMv2Config,
|
LayoutLMv2Config,
|
||||||
LayoutLMv2ForQuestionAnswering,
|
LayoutLMv2ForQuestionAnswering,
|
||||||
LayoutLMv2ForSequenceClassification,
|
LayoutLMv2ForSequenceClassification,
|
||||||
@@ -312,54 +308,6 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||||
|
|
||||||
def test_save_load_fast_init_from_base(self):
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
base_class = MODEL_MAPPING[config.__class__]
|
|
||||||
|
|
||||||
if isinstance(base_class, tuple):
|
|
||||||
base_class = base_class[0]
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
if model_class == base_class:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# make a copy of model class to not break future tests
|
|
||||||
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
|
|
||||||
class CopyClass(model_class):
|
|
||||||
pass
|
|
||||||
|
|
||||||
model_class_copy = CopyClass
|
|
||||||
|
|
||||||
# make sure that all keys are expected for test
|
|
||||||
model_class_copy._keys_to_ignore_on_load_missing = []
|
|
||||||
|
|
||||||
# make init deterministic, but make sure that
|
|
||||||
# non-initialized weights throw errors nevertheless
|
|
||||||
model_class_copy._init_weights = self._mock_init_weights
|
|
||||||
|
|
||||||
model = base_class(config)
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
|
|
||||||
# this will often delete a single weight of a multi-weight module
|
|
||||||
# to test an edge case
|
|
||||||
random_key_to_del = random.choice(list(state_dict.keys()))
|
|
||||||
del state_dict[random_key_to_del]
|
|
||||||
|
|
||||||
# check that certain keys didn't get saved with the model
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
model.save_pretrained(tmpdirname)
|
|
||||||
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
|
|
||||||
|
|
||||||
model_fast_init = model_class_copy.from_pretrained(tmpdirname)
|
|
||||||
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
|
||||||
|
|
||||||
for key in model_fast_init.state_dict().keys():
|
|
||||||
if key == "layoutlmv2.visual_segment_embedding":
|
|
||||||
# we skip the visual segment embedding as it has a custom initialization scheme
|
|
||||||
continue
|
|
||||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
|
||||||
|
|
||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
|
|||||||
@@ -436,10 +436,10 @@ class ProphetNetModelTester:
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
labels=lm_labels,
|
labels=lm_labels,
|
||||||
)
|
)
|
||||||
self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(4.5819, device=torch_device), atol=1e-3))
|
self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(4.5981, device=torch_device), atol=1e-3))
|
||||||
|
|
||||||
expected_logit_slice = torch.tensor(
|
expected_logit_slice = torch.tensor(
|
||||||
[-0.1565, 0.0418, 0.1207, 0.0030, 0.0665, 0.0467, 0.0412], device=torch_device
|
[-0.0648, 0.0790, 0.0360, 0.0089, 0.0039, -0.0639, 0.0131], device=torch_device
|
||||||
)
|
)
|
||||||
self.parent.assertTrue(torch.allclose(result.logits[0, :, 1], expected_logit_slice, atol=1e-3))
|
self.parent.assertTrue(torch.allclose(result.logits[0, :, 1], expected_logit_slice, atol=1e-3))
|
||||||
|
|
||||||
|
|||||||
@@ -1145,10 +1145,11 @@ class ReformerIntegrationTests(unittest.TestCase):
|
|||||||
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
|
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
|
||||||
output_slice = hidden_states[1, -1, :5]
|
output_slice = hidden_states[1, -1, :5]
|
||||||
expected_output_slice = torch.tensor(
|
expected_output_slice = torch.tensor(
|
||||||
[0.0256, -0.0121, 0.0636, 0.0024, -0.0393],
|
[0.1018, -0.2026, 0.2116, 0.0270, -0.1233],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||||
|
|
||||||
def test_local_lm_model_grad(self):
|
def test_local_lm_model_grad(self):
|
||||||
@@ -1163,25 +1164,25 @@ class ReformerIntegrationTests(unittest.TestCase):
|
|||||||
input_ids, _ = self._get_input_ids_and_mask()
|
input_ids, _ = self._get_input_ids_and_mask()
|
||||||
loss = model(input_ids=input_ids, labels=input_ids)[0]
|
loss = model(input_ids=input_ids, labels=input_ids)[0]
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(loss, torch.tensor(5.7786, dtype=torch.float, device=torch_device), atol=1e-3))
|
self.assertTrue(torch.allclose(loss, torch.tensor(5.8019, dtype=torch.float, device=torch_device), atol=1e-3))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# check last grads to cover all probable errors
|
# check last grads to cover all probable errors
|
||||||
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
|
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
|
||||||
expected_grad_slice_word = torch.tensor(
|
expected_grad_slice_word = torch.tensor(
|
||||||
[-0.0005, 0.0001, 0.0002, 0.0003, 0.0006],
|
[-0.0005, -0.0001, -0.0002, -0.0006, -0.0006],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
|
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
|
||||||
expected_grad_slice_pos_fac_1 = torch.tensor(
|
expected_grad_slice_pos_fac_1 = torch.tensor(
|
||||||
[0.0037, -1.3793, -1.0231, -1.5230, -2.5306],
|
[-0.5235, 0.5704, 0.0922, -0.3140, 0.9928],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
|
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
|
||||||
expected_grad_slice_pos_fac_2 = torch.tensor(
|
expected_grad_slice_pos_fac_2 = torch.tensor(
|
||||||
[-1.3165, 0.5168, 0.7785, 1.0811, -0.9830],
|
[1.7960, 1.7668, 0.5593, 0.0907, 1.8342],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
@@ -1203,24 +1204,24 @@ class ReformerIntegrationTests(unittest.TestCase):
|
|||||||
input_ids, _ = self._get_input_ids_and_mask()
|
input_ids, _ = self._get_input_ids_and_mask()
|
||||||
loss = model(input_ids=input_ids, labels=input_ids)[0]
|
loss = model(input_ids=input_ids, labels=input_ids)[0]
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(loss, torch.tensor(5.7819, dtype=torch.float, device=torch_device), atol=1e-3))
|
self.assertTrue(torch.allclose(loss, torch.tensor(5.7854, dtype=torch.float, device=torch_device), atol=1e-3))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
# check last grads to cover all probable errors
|
# check last grads to cover all probable errors
|
||||||
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
|
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
|
||||||
expected_grad_slice_word = torch.tensor(
|
expected_grad_slice_word = torch.tensor(
|
||||||
[2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04],
|
[0.0004, 0.0003, 0.0006, -0.0004, 0.0002],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
|
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
|
||||||
expected_grad_slice_pos_fac_1 = torch.tensor(
|
expected_grad_slice_pos_fac_1 = torch.tensor(
|
||||||
[-0.0984, 0.6283, 0.4282, 1.2960, 0.6897],
|
[-0.3792, 0.5593, -1.6993, 0.2033, 0.4131],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
|
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
|
||||||
expected_grad_slice_pos_fac_2 = torch.tensor(
|
expected_grad_slice_pos_fac_2 = torch.tensor(
|
||||||
[0.4626, -0.0231, -0.0172, 0.1081, 0.3805],
|
[-1.4212, -0.3201, -1.1944, 0.1258, 0.2856],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from transformers.testing_utils import require_accelerate, require_torch, requir
|
|||||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -198,6 +198,28 @@ class ViTHybridModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
# Skip the check for the backbone
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if module.__class__.__name__ == "ViTHybridPatchEmbeddings":
|
||||||
|
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
|
||||||
|
break
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if name in backbone_params:
|
||||||
|
continue
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -69,7 +69,6 @@ from transformers.testing_utils import (
|
|||||||
USER,
|
USER,
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
is_flaky,
|
|
||||||
is_pt_flax_cross_test,
|
is_pt_flax_cross_test,
|
||||||
is_pt_tf_cross_test,
|
is_pt_tf_cross_test,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
@@ -175,6 +174,9 @@ def _config_zero_init(config):
|
|||||||
for key in configs_no_init.__dict__.keys():
|
for key in configs_no_init.__dict__.keys():
|
||||||
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
|
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
|
||||||
setattr(configs_no_init, key, 1e-10)
|
setattr(configs_no_init, key, 1e-10)
|
||||||
|
if isinstance(getattr(configs_no_init, key, None), PretrainedConfig):
|
||||||
|
no_init_subconfig = _config_zero_init(getattr(configs_no_init, key))
|
||||||
|
setattr(configs_no_init, key, no_init_subconfig)
|
||||||
return configs_no_init
|
return configs_no_init
|
||||||
|
|
||||||
|
|
||||||
@@ -182,6 +184,31 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
|||||||
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_init_weights(self, module):
|
||||||
|
for name, param in module.named_parameters(recurse=False):
|
||||||
|
# Use the first letter of the name to get a value and go from a <> -13 to z <> 12
|
||||||
|
value = ord(name[0].lower()) - 110
|
||||||
|
param.data.fill_(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_all_init_weights(self):
|
||||||
|
# Prune heads if needed
|
||||||
|
if self.config.pruned_heads:
|
||||||
|
self.prune_heads(self.config.pruned_heads)
|
||||||
|
|
||||||
|
import transformers.modeling_utils
|
||||||
|
|
||||||
|
if transformers.modeling_utils._init_weights:
|
||||||
|
for module in self.modules():
|
||||||
|
module._is_hf_initialized = False
|
||||||
|
# Initialize weights
|
||||||
|
self.apply(self._initialize_weights)
|
||||||
|
|
||||||
|
# Tie weights should be skipped when not initializing all weights
|
||||||
|
# since from_pretrained(...) calls tie weights anyways
|
||||||
|
self.tie_weights()
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ModelTesterMixin:
|
class ModelTesterMixin:
|
||||||
model_tester = None
|
model_tester = None
|
||||||
@@ -357,15 +384,10 @@ class ModelTesterMixin:
|
|||||||
model.gradient_checkpointing_disable()
|
model.gradient_checkpointing_disable()
|
||||||
self.assertFalse(model.is_gradient_checkpointing)
|
self.assertFalse(model.is_gradient_checkpointing)
|
||||||
|
|
||||||
def _mock_init_weights(self, module):
|
|
||||||
if hasattr(module, "weight") and module.weight is not None:
|
|
||||||
module.weight.data.fill_(3)
|
|
||||||
if hasattr(module, "bias") and module.bias is not None:
|
|
||||||
module.bias.data.fill_(3)
|
|
||||||
|
|
||||||
@is_flaky()
|
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if config.__class__ not in MODEL_MAPPING:
|
||||||
|
return
|
||||||
base_class = MODEL_MAPPING[config.__class__]
|
base_class = MODEL_MAPPING[config.__class__]
|
||||||
|
|
||||||
if isinstance(base_class, tuple):
|
if isinstance(base_class, tuple):
|
||||||
@@ -387,7 +409,8 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
# make init deterministic, but make sure that
|
# make init deterministic, but make sure that
|
||||||
# non-initialized weights throw errors nevertheless
|
# non-initialized weights throw errors nevertheless
|
||||||
model_class_copy._init_weights = self._mock_init_weights
|
model_class_copy._init_weights = _mock_init_weights
|
||||||
|
model_class_copy.init_weights = _mock_all_init_weights
|
||||||
|
|
||||||
model = base_class(config)
|
model = base_class(config)
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
@@ -404,13 +427,16 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
model_fast_init = model_class_copy.from_pretrained(tmpdirname)
|
model_fast_init = model_class_copy.from_pretrained(tmpdirname)
|
||||||
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||||
|
# Before we test anything
|
||||||
|
|
||||||
for key in model_fast_init.state_dict().keys():
|
for key in model_fast_init.state_dict().keys():
|
||||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
||||||
|
|
||||||
def test_save_load_fast_init_to_base(self):
|
def test_save_load_fast_init_to_base(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if config.__class__ not in MODEL_MAPPING:
|
||||||
|
return
|
||||||
base_class = MODEL_MAPPING[config.__class__]
|
base_class = MODEL_MAPPING[config.__class__]
|
||||||
|
|
||||||
if isinstance(base_class, tuple):
|
if isinstance(base_class, tuple):
|
||||||
@@ -432,7 +458,8 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
# make init deterministic, but make sure that
|
# make init deterministic, but make sure that
|
||||||
# non-initialized weights throw errors nevertheless
|
# non-initialized weights throw errors nevertheless
|
||||||
base_class_copy._init_weights = self._mock_init_weights
|
base_class_copy._init_weights = _mock_init_weights
|
||||||
|
base_class_copy.init_weights = _mock_all_init_weights
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
@@ -454,7 +481,7 @@ class ModelTesterMixin:
|
|||||||
max_diff = torch.max(
|
max_diff = torch.max(
|
||||||
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
||||||
).item()
|
).item()
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user