From 89d47230d74061d57b700eed056f6a6763315610 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Nov 2018 22:54:53 +0100 Subject: [PATCH] clean up classification model output --- examples/run_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 52f3cd752d..23c2bea057 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -546,7 +546,7 @@ def main(): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, label_ids = batch - loss, _ = model(input_ids, segment_ids, input_mask, label_ids) + loss = model(input_ids, segment_ids, input_mask, label_ids) if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.fp16 and args.loss_scale != 1.0: