Quick fix on eval accuracy
This commit is contained in:
@@ -548,6 +548,7 @@ def main():
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
eval_accuracy = 0
|
||||
nb_eval_examples = 0
|
||||
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
@@ -562,9 +563,11 @@ def main():
|
||||
|
||||
eval_loss += tmp_eval_loss.item()
|
||||
eval_accuracy += tmp_eval_accuracy
|
||||
|
||||
nb_eval_examples += input_ids.size(0)
|
||||
|
||||
eval_loss = eval_loss / len(eval_dataloader)
|
||||
eval_accuracy = eval_accuracy / len(eval_dataloader)
|
||||
eval_loss = eval_loss / nb_eval_examples #len(eval_dataloader)
|
||||
eval_accuracy = eval_accuracy / nb_eval_examples #len(eval_dataloader)
|
||||
|
||||
result = {'eval_loss': eval_loss,
|
||||
'eval_accuracy': eval_accuracy,
|
||||
|
||||
Reference in New Issue
Block a user