Cleaning up and speeding up multi-GPU a bit

This commit is contained in:
thomwolf
2018-11-07 22:22:55 +01:00
parent 6bb7510a50
commit dbc318a4c6
3 changed files with 7 additions and 6 deletions

View File

@@ -514,13 +514,13 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train()
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
loss = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
@@ -564,7 +564,8 @@ def main():
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
with torch.no_grad():
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()