diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index 943b1be18c..569a8b4585 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -548,6 +548,7 @@ def main(): model.eval() eval_loss = 0 eval_accuracy = 0 + nb_eval_examples = 0 for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) @@ -562,9 +563,11 @@ def main(): eval_loss += tmp_eval_loss.item() eval_accuracy += tmp_eval_accuracy + + nb_eval_examples += input_ids.size(0) - eval_loss = eval_loss / len(eval_dataloader) - eval_accuracy = eval_accuracy / len(eval_dataloader) + eval_loss = eval_loss / nb_eval_examples #len(eval_dataloader) + eval_accuracy = eval_accuracy / nb_eval_examples #len(eval_dataloader) result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy,