From 57c8e822f7faa1c19f9926338f21f3aab2269997 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 30 Apr 2021 18:17:01 +0530 Subject: [PATCH] reszie token embeds (#11524) --- examples/pytorch/summarization/run_summarization.py | 2 ++ examples/pytorch/translation/run_translation.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index 05291a85fe..c310cbd4f4 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -353,6 +353,8 @@ def main(): use_auth_token=True if model_args.use_auth_token else None, ) + model.resize_token_embeddings(len(tokenizer)) + if model.config.decoder_start_token_id is None: raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index 125ab70710..56503f98ef 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -337,6 +337,8 @@ def main(): use_auth_token=True if model_args.use_auth_token else None, ) + model.resize_token_embeddings(len(tokenizer)) + # Set decoder_start_token_id if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): if isinstance(tokenizer, MBartTokenizer):