Only resize embeddings when necessary (#20043)

* Only resize embeddings when necessary

* Add comment
This commit is contained in:
Sylvain Gugger
2022-11-03 12:05:04 -04:00
committed by GitHub
parent 9080607b2c
commit 06886d5a68
17 changed files with 87 additions and 17 deletions

View File

@@ -422,7 +422,11 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
model.resize_token_embeddings(len(tokenizer))
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):