🚨🚨🚨 Enforce single model initialization (#21431)
* Enforce single model initialization * Add OneFormer example for problem 3 * Do it the Stas way * Actually rename the uses... * Rewrite test * Try to change the test this way * Fix all init slow/fast tests * Break connection * Fix more tests * Fix test for initialization * Remove custom test * Quality * Fix last failing tests * The end?
This commit is contained in:
@@ -492,6 +492,48 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
||||
The above command will create a model according to the default parameters as defined in `BrandNewBertConfig()` with
|
||||
random weights, thus making sure that the `init()` methods of all components works.
|
||||
|
||||
Note that all random initialization should happen in the `_init_weights` method of your `BrandnewBertPreTrainedModel`
|
||||
class. It should initialize all leaf modules depending on the variables of the config. Here is an example with the
|
||||
BERT `_init_weights` method:
|
||||
|
||||
```py
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
```
|
||||
|
||||
You can have some more custom schemes if you need a special initialization for some modules. For instance, in
|
||||
`Wav2Vec2ForPreTraining`, the last two linear layers need to have the initialization of the regular PyTorch `nn.Linear`
|
||||
but all the other ones should use an initialization as above. This is coded like this:
|
||||
|
||||
```py
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstnace(module, Wav2Vec2ForPreTraining):
|
||||
module.project_hid.reset_parameters()
|
||||
module.project_q.reset_parameters()
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
```
|
||||
|
||||
The `_is_hf_initialized` flag is internally used to make sure we only initialize a submodule once. By setting it to
|
||||
`True` for `module.project_q` and `module.project_hid`, we make sure the custom initialization we did is not overridden later on,
|
||||
the `_init_weights` function won't be applied to them.
|
||||
|
||||
**6. Write a conversion script**
|
||||
|
||||
Next, you should write a conversion script that lets you convert the checkpoint you used to debug *brand_new_bert* in
|
||||
|
||||
Reference in New Issue
Block a user