This commit is contained in:
Davide Fiocco
2020-04-16 17:04:32 +02:00
committed by GitHub
parent baca8fa8e6
commit b1e2368b32

View File

@@ -162,7 +162,7 @@ def train(args, train_dataset, model, tokenizer):
train_iterator = trange( train_iterator = trange(
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0], epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
) )
set_seed(args) # Added here for reproductibility set_seed(args) # Added here for reproducibility
for _ in train_iterator: for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator): for step, batch in enumerate(epoch_iterator):