From b54de837c2244c7dac6c058429ebe3cd7efe4dcc Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 02:46:17 -0400 Subject: [PATCH] Quick fix on eval accuracy --- run_classifier_pytorch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index 943b1be18c..569a8b4585 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -548,6 +548,7 @@ def main(): model.eval() eval_loss = 0 eval_accuracy = 0 + nb_eval_examples = 0 for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) @@ -562,9 +563,11 @@ def main(): eval_loss += tmp_eval_loss.item() eval_accuracy += tmp_eval_accuracy + + nb_eval_examples += input_ids.size(0) - eval_loss = eval_loss / len(eval_dataloader) - eval_accuracy = eval_accuracy / len(eval_dataloader) + eval_loss = eval_loss / nb_eval_examples #len(eval_dataloader) + eval_accuracy = eval_accuracy / nb_eval_examples #len(eval_dataloader) result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy,