Quick fix on eval accuracy
This commit is contained in:
@@ -548,6 +548,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
eval_loss = 0
|
eval_loss = 0
|
||||||
eval_accuracy = 0
|
eval_accuracy = 0
|
||||||
|
nb_eval_examples = 0
|
||||||
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
|
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.float().to(device)
|
||||||
@@ -562,9 +563,11 @@ def main():
|
|||||||
|
|
||||||
eval_loss += tmp_eval_loss.item()
|
eval_loss += tmp_eval_loss.item()
|
||||||
eval_accuracy += tmp_eval_accuracy
|
eval_accuracy += tmp_eval_accuracy
|
||||||
|
|
||||||
|
nb_eval_examples += input_ids.size(0)
|
||||||
|
|
||||||
eval_loss = eval_loss / len(eval_dataloader)
|
eval_loss = eval_loss / nb_eval_examples #len(eval_dataloader)
|
||||||
eval_accuracy = eval_accuracy / len(eval_dataloader)
|
eval_accuracy = eval_accuracy / nb_eval_examples #len(eval_dataloader)
|
||||||
|
|
||||||
result = {'eval_loss': eval_loss,
|
result = {'eval_loss': eval_loss,
|
||||||
'eval_accuracy': eval_accuracy,
|
'eval_accuracy': eval_accuracy,
|
||||||
|
|||||||
Reference in New Issue
Block a user