Add a post init method to all models (#14431)
* Add a post init method to all models * Fix tests * Fix last tests * Fix templates * Add comment * Forgot to save
This commit is contained in:
@@ -777,7 +777,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
||||
self.embeddings = {{cookiecutter.camelcase_modelname}}Embeddings(config)
|
||||
self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
@@ -943,7 +944,8 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m
|
||||
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
|
||||
self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
@@ -1046,7 +1048,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
|
||||
self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
@@ -1217,7 +1220,8 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
|
||||
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
|
||||
self.classifier = {{cookiecutter.camelcase_modelname}}ClassificationHead(config)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
@@ -1309,7 +1313,8 @@ class {{cookiecutter.camelcase_modelname}}ForMultipleChoice({{cookiecutter.camel
|
||||
self.sequence_summary = SequenceSummary(config)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
@@ -1399,7 +1404,8 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
@@ -1486,7 +1492,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
|
||||
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
@@ -2224,8 +2231,9 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
||||
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.init_weights()
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -2388,8 +2396,9 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}DecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
||||
self.init_weights()
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
@@ -2640,7 +2649,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
||||
self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config, self.shared)
|
||||
self.decoder = {{cookiecutter.camelcase_modelname}}Decoder(config, self.shared)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
@@ -2755,7 +2765,8 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
||||
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.get_encoder()
|
||||
@@ -3170,7 +3181,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.decoder.embed_tokens
|
||||
|
||||
Reference in New Issue
Block a user