Add POS tagging and Phrase chunking token classification examples (#6457)
* Add more token classification examples * POS tagging example * Phrase chunking example * PR review fixes * Add conllu to third party list (used in token classification examples)
This commit is contained in:
@@ -14,16 +14,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Fine-tuning the library models for named entity recognition on CoNLL-2003. """
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from importlib import import_module
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from seqeval.metrics import f1_score, precision_score, recall_score
|
||||
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
@@ -36,7 +35,7 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from utils_ner import NerDataset, Split, get_labels
|
||||
from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -54,6 +53,9 @@ class ModelArguments:
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
task_type: Optional[str] = field(
|
||||
default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
@@ -113,6 +115,16 @@ def main():
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
module = import_module("tasks")
|
||||
try:
|
||||
token_classification_task_clazz = getattr(module, model_args.task_type)
|
||||
token_classification_task: TokenClassificationTask = token_classification_task_clazz()
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
|
||||
f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -133,7 +145,7 @@ def main():
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Prepare CONLL-2003 task
|
||||
labels = get_labels(data_args.labels)
|
||||
labels = token_classification_task.get_labels(data_args.labels)
|
||||
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
|
||||
num_labels = len(labels)
|
||||
|
||||
@@ -164,7 +176,8 @@ def main():
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
NerDataset(
|
||||
TokenClassificationDataset(
|
||||
token_classification_task=token_classification_task,
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
labels=labels,
|
||||
@@ -177,7 +190,8 @@ def main():
|
||||
else None
|
||||
)
|
||||
eval_dataset = (
|
||||
NerDataset(
|
||||
TokenClassificationDataset(
|
||||
token_classification_task=token_classification_task,
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
labels=labels,
|
||||
@@ -209,6 +223,7 @@ def main():
|
||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||
preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
|
||||
return {
|
||||
"accuracy_score": accuracy_score(out_label_list, preds_list),
|
||||
"precision": precision_score(out_label_list, preds_list),
|
||||
"recall": recall_score(out_label_list, preds_list),
|
||||
"f1": f1_score(out_label_list, preds_list),
|
||||
@@ -253,7 +268,8 @@ def main():
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
test_dataset = NerDataset(
|
||||
test_dataset = TokenClassificationDataset(
|
||||
token_classification_task=token_classification_task,
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
labels=labels,
|
||||
@@ -278,19 +294,7 @@ def main():
|
||||
if trainer.is_world_master():
|
||||
with open(output_test_predictions_file, "w") as writer:
|
||||
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
|
||||
example_id = 0
|
||||
for line in f:
|
||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||
writer.write(line)
|
||||
if not preds_list[example_id]:
|
||||
example_id += 1
|
||||
elif preds_list[example_id]:
|
||||
output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
|
||||
writer.write(output_line)
|
||||
else:
|
||||
logger.warning(
|
||||
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
|
||||
)
|
||||
token_classification_task.write_predictions_to_file(writer, f, preds_list)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user