Switch metrics in run_ner to datasets (#9567)
* Switch metrics in run_ner to datasets
* Add flag to return all metrics
* Upstream (and rename) sortish_sampler
* Revert "Upstream (and rename) sortish_sampler"

This reverts commit e07d0dcf650c2bae36da011dd76c77a8bb4feb0d.
This commit is contained in:
@@ -184,7 +184,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_ner.main()
|
result = run_ner.main()
|
||||||
self.assertGreaterEqual(result["eval_accuracy_score"], 0.75)
|
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||||
self.assertGreaterEqual(result["eval_precision"], 0.75)
|
self.assertGreaterEqual(result["eval_precision"], 0.75)
|
||||||
self.assertLess(result["eval_loss"], 0.5)
|
self.assertLess(result["eval_loss"], 0.5)
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import ClassLabel, load_dataset
|
from datasets import ClassLabel, load_dataset, load_metric
|
||||||
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
|
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -124,6 +123,10 @@ class DataTrainingArguments:
|
|||||||
"one (in which case the other tokens will have a padding index)."
|
"one (in which case the other tokens will have a padding index)."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
return_entity_level_metrics: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||||
@@ -323,6 +326,8 @@ def main():
|
|||||||
data_collator = DataCollatorForTokenClassification(tokenizer)
|
data_collator = DataCollatorForTokenClassification(tokenizer)
|
||||||
|
|
||||||
# Metrics
|
# Metrics
|
||||||
|
metric = load_metric("seqeval")
|
||||||
|
|
||||||
def compute_metrics(p):
|
def compute_metrics(p):
|
||||||
predictions, labels = p
|
predictions, labels = p
|
||||||
predictions = np.argmax(predictions, axis=2)
|
predictions = np.argmax(predictions, axis=2)
|
||||||
@@ -337,12 +342,24 @@ def main():
|
|||||||
for prediction, label in zip(predictions, labels)
|
for prediction, label in zip(predictions, labels)
|
||||||
]
|
]
|
||||||
|
|
||||||
return {
|
results = metric.compute(predictions=true_predictions, references=true_labels)
|
||||||
"accuracy_score": accuracy_score(true_labels, true_predictions),
|
if data_args.return_entity_level_metrics:
|
||||||
"precision": precision_score(true_labels, true_predictions),
|
# Unpack nested dictionaries
|
||||||
"recall": recall_score(true_labels, true_predictions),
|
final_results = {}
|
||||||
"f1": f1_score(true_labels, true_predictions),
|
for key, value in results.items():
|
||||||
}
|
if isinstance(value, dict):
|
||||||
|
for n, v in value.items():
|
||||||
|
final_results[f"{key}_{n}"] = v
|
||||||
|
else:
|
||||||
|
final_results[key] = value
|
||||||
|
return final_results
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"precision": results["overall_precision"],
|
||||||
|
"recall": results["overall_recall"],
|
||||||
|
"f1": results["overall_f1"],
|
||||||
|
"accuracy": results["overall_accuracy"],
|
||||||
|
}
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|||||||
Reference in New Issue
Block a user