From e60e8a606837ff7f49e583de8492e55575155eb6 Mon Sep 17 00:00:00 2001 From: Davide Fiocco Date: Sun, 2 Dec 2018 12:38:26 +0100 Subject: [PATCH] Correct assignement for logits in classifier example I tried to address https://github.com/huggingface/pytorch-pretrained-BERT/issues/76 should be correct, but there's likely a more efficient way. --- examples/run_classifier.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 8e136da37b..a5e7d2c30d 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -605,7 +605,8 @@ def main(): label_ids = label_ids.to(device) with torch.no_grad(): - tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) + tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids) + logits = model(input_ids, segment_ids, input_mask) logits = logits.detach().cpu().numpy() label_ids = label_ids.to('cpu').numpy()