multi-gpu cleanup
This commit is contained in:
@@ -539,9 +539,11 @@ def main():
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
loss = loss.mean() # sum() is to account for multi-gpu support.
|
||||
if n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu.
|
||||
tr_loss += loss.item()
|
||||
nb_tr_examples += input_ids.size(0)
|
||||
|
||||
model.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
Reference in New Issue
Block a user