diff --git a/examples/run_language_modeling.py b/examples/run_language_modeling.py index 0890598d5e..2b0163d96a 100644 --- a/examples/run_language_modeling.py +++ b/examples/run_language_modeling.py @@ -233,6 +233,9 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke else: t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + model = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training + model.resize_token_embeddings(len(tokenizer)) + # Prepare optimizer and schedule (linear warmup and decay) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ @@ -309,9 +312,6 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke tr_loss, logging_loss = 0.0, 0.0 - model_to_resize = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training - model_to_resize.resize_token_embeddings(len(tokenizer)) - model.zero_grad() train_iterator = trange( epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]