Only resize embeddings when necessary (#20043)
* Only resize embeddings when necessary * Add comment
This commit is contained in:
@@ -414,7 +414,13 @@ def main():
|
||||
logger.info("Training new model from scratch")
|
||||
model = AutoModelForTokenClassification.from_config(config)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
|
||||
# on a small vocab and want a smaller embedding size, remove this test.
|
||||
embedding_size = model.get_input_embeddings().weight.shape[0]
|
||||
if len(tokenizer) > embedding_size:
|
||||
embedding_size = model.get_input_embeddings().weight.shape[0]
|
||||
if len(tokenizer) > embedding_size:
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Model has labels -> use them.
|
||||
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
||||
|
||||
Reference in New Issue
Block a user