From 5bfcd0485ece086ebcbed2d008813037968a9e58 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 4 Dec 2019 14:53:11 +0100 Subject: [PATCH] fix #1991 --- examples/run_lm_finetuning.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 0bb7460353..a5eaf524ac 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -217,7 +217,10 @@ def train(args, train_dataset, model, tokenizer): global_step = 0 tr_loss, logging_loss = 0.0, 0.0 - model.resize_token_embeddings(len(tokenizer)) + + model_to_resize = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + model_to_resize.resize_token_embeddings(len(tokenizer)) + model.zero_grad() train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed(args) # Added here for reproducibility (even between python 2 and 3)