Add a post init method to all models (#14431)

* Add a post init method to all models

* Fix tests

* Fix last tests

* Fix templates

* Add comment

* Forgot to save
This commit is contained in:
Sylvain Gugger
2021-11-18 08:38:09 -05:00
committed by GitHub
parent 08816de16a
commit d83b0e0c07
70 changed files with 693 additions and 359 deletions

View File

@@ -777,7 +777,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
self.embeddings = {{cookiecutter.camelcase_modelname}}Embeddings(config)
self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
@@ -943,7 +944,8 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
def get_output_embeddings(self):
return self.cls.predictions.decoder
@@ -1046,7 +1048,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
def get_output_embeddings(self):
return self.cls.predictions.decoder
@@ -1217,7 +1220,8 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.classifier = {{cookiecutter.camelcase_modelname}}ClassificationHead(config)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1309,7 +1313,8 @@ class {{cookiecutter.camelcase_modelname}}ForMultipleChoice({{cookiecutter.camel
self.sequence_summary = SequenceSummary(config)
self.classifier = nn.Linear(config.hidden_size, 1)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
@@ -1399,7 +1404,8 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1486,7 +1492,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -2224,8 +2231,9 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
-self.init_weights()
self.gradient_checkpointing = False
+# Initialize weights and apply final processing
+self.post_init()
def forward(
self,
@@ -2388,8 +2396,9 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}DecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
-self.init_weights()
self.gradient_checkpointing = False
+# Initialize weights and apply final processing
+self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
@@ -2640,7 +2649,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config, self.shared)
self.decoder = {{cookiecutter.camelcase_modelname}}Decoder(config, self.shared)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
def get_input_embeddings(self):
return self.shared
@@ -2755,7 +2765,8 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
def get_encoder(self):
return self.model.get_encoder()
@@ -3170,7 +3181,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-self.init_weights()
+# Initialize weights and apply final processing
+self.post_init()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens