Add mean_resizing for every VLMs' resizing_token_embeddings() (#35717)

* refine all resize_token_embedding()

* ruff format

* hotfix
This commit is contained in:
Gar
2025-02-03 22:03:49 +08:00
committed by GitHub
parent 7eecdf2a86
commit 9d2056f12b
16 changed files with 71 additions and 32 deletions

View File

@@ -73,8 +73,9 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of=None,
mean_resizing=True
) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
# Update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings