multi-gpu cleanup

This commit is contained in:
thomwolf
2018-11-04 11:54:57 +01:00
parent 5ee171689c
commit 1701291ef9
2 changed files with 6 additions and 11 deletions

View File

@@ -539,9 +539,11 @@ def main():
label_ids = label_ids.to(device)
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
loss = loss.mean() # sum() is to account for multi-gpu support.
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
model.zero_grad()
loss.backward()
optimizer.step()