clean up classification model output
This commit is contained in:
@@ -546,7 +546,7 @@ def main():
|
|||||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||||
batch = tuple(t.to(device) for t in batch)
|
batch = tuple(t.to(device) for t in batch)
|
||||||
input_ids, input_mask, segment_ids, label_ids = batch
|
input_ids, input_mask, segment_ids, label_ids = batch
|
||||||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
loss = model(input_ids, segment_ids, input_mask, label_ids)
|
||||||
if n_gpu > 1:
|
if n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu.
|
loss = loss.mean() # mean() to average on multi-gpu.
|
||||||
if args.fp16 and args.loss_scale != 1.0:
|
if args.fp16 and args.loss_scale != 1.0:
|
||||||
|
|||||||
Reference in New Issue
Block a user