Only resize embeddings when necessary (#20043)

* Only resize embeddings when necessary

* Add comment
This commit is contained in:
Sylvain Gugger
2022-11-03 12:05:04 -04:00
committed by GitHub
parent 9080607b2c
commit 06886d5a68
17 changed files with 87 additions and 17 deletions

View File

@@ -414,7 +414,13 @@ def main():
logger.info("Training new model from scratch")
model = AutoModelForTokenClassification.from_config(config)
model.resize_token_embeddings(len(tokenizer))
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
# Model has labels -> use them.
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: