reszie token embeds (#11524)

This commit is contained in:
Suraj Patil
2021-04-30 18:17:01 +05:30
committed by GitHub
parent 20d6931e32
commit 57c8e822f7
2 changed files with 4 additions and 0 deletions

View File

@@ -353,6 +353,8 @@ def main():
use_auth_token=True if model_args.use_auth_token else None, use_auth_token=True if model_args.use_auth_token else None,
) )
model.resize_token_embeddings(len(tokenizer))
if model.config.decoder_start_token_id is None: if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

View File

@@ -337,6 +337,8 @@ def main():
use_auth_token=True if model_args.use_auth_token else None, use_auth_token=True if model_args.use_auth_token else None,
) )
model.resize_token_embeddings(len(tokenizer))
# Set decoder_start_token_id # Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer): if isinstance(tokenizer, MBartTokenizer):