fix #1991
This commit is contained in:
@@ -217,7 +217,10 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
model.resize_token_embeddings(len(tokenizer))
|
|
||||||
|
model_to_resize = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||||
|
model_to_resize.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||||
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
||||||
|
|||||||
Reference in New Issue
Block a user