reszie token embeds (#11524)
This commit is contained in:
@@ -353,6 +353,8 @@ def main():
|
|||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
if model.config.decoder_start_token_id is None:
|
if model.config.decoder_start_token_id is None:
|
||||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||||
|
|
||||||
|
|||||||
@@ -337,6 +337,8 @@ def main():
|
|||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
# Set decoder_start_token_id
|
# Set decoder_start_token_id
|
||||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||||
if isinstance(tokenizer, MBartTokenizer):
|
if isinstance(tokenizer, MBartTokenizer):
|
||||||
|
|||||||
Reference in New Issue
Block a user