Refactor embedding input/output getter/setter (#39339)

* simplify common get/set

* remove some noise

* change some 5 years old modeling utils

* update examples

* fix copies

* revert some changes

* fixes, gah

* format

* move to Mixin

* remove smolvlm specific require grad

* skip

* force defaults

* remodularise some stuff

* remodularise more stuff

* add safety for audio models

* style

* have a correct fallback, you daft donkey

* remove this argh

* change heuristic for audio models

* fixup

* revert

* this works

* revert again

* 🧠

* aaah ESM has two modelings aaah

* add informative but short comment

* add `input_embed_layer` mixin attribute

* style

* walrus has low precedence

* modular fix

* this was breaking parser
This commit is contained in:
Pablo Montalvo
2025-07-21 18:18:14 +02:00
committed by GitHub
parent 2da97f0943
commit 69b158260f
163 changed files with 235 additions and 2388 deletions

View File

@@ -333,12 +333,6 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@check_model_inputs
@auto_docstring
def forward(
@@ -433,12 +427,6 @@ class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@can_return_tuple
@auto_docstring
def forward(

View File

@@ -389,12 +389,6 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.set_decoder(decoder)

View File

@@ -332,12 +332,6 @@ class SuperModel(SuperPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@check_model_inputs
@auto_docstring
def forward(