Add POS tagging and Phrase chunking token classification examples (#6457)

* Add more token classification examples

* POS tagging example

* Phrase chunking example

* PR review fixes

* Add conllu to third party list (used in token classification examples)
This commit is contained in:
vblagoje
2020-08-13 12:09:51 -04:00
committed by GitHub
parent f51161e230
commit eda07efaa5
10 changed files with 473 additions and 204 deletions

View File

@@ -2,15 +2,17 @@ import argparse
import glob
import logging
import os
from argparse import Namespace
from importlib import import_module
import numpy as np
import torch
from seqeval.metrics import f1_score, precision_score, recall_score
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, TensorDataset
from lightning_base import BaseTransformer, add_generic_args, generic_train
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
from utils_ner import TokenClassificationTask
logger = logging.getLogger(__name__)
@@ -24,10 +26,20 @@ class NERTransformer(BaseTransformer):
mode = "token-classification"
def __init__(self, hparams):
self.labels = get_labels(hparams.labels)
num_labels = len(self.labels)
if type(hparams) == dict:
hparams = Namespace(**hparams)
module = import_module("tasks")
try:
token_classification_task_clazz = getattr(module, hparams.task_type)
self.token_classification_task: TokenClassificationTask = token_classification_task_clazz()
except AttributeError:
raise ValueError(
f"Task {hparams.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
)
self.labels = self.token_classification_task.get_labels(hparams.labels)
self.pad_token_label_id = CrossEntropyLoss().ignore_index
super().__init__(hparams, num_labels, self.mode)
super().__init__(hparams, len(self.labels), self.mode)
def forward(self, **inputs):
return self.model(**inputs)
@@ -42,8 +54,8 @@ class NERTransformer(BaseTransformer):
outputs = self(**inputs)
loss = outputs[0]
tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
return {"loss": loss, "log": tensorboard_logs}
# tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
return {"loss": loss}
def prepare_data(self):
"Called to initialize data. Use the call to construct features"
@@ -55,8 +67,8 @@ class NERTransformer(BaseTransformer):
features = torch.load(cached_features_file)
else:
logger.info("Creating features from dataset file at %s", args.data_dir)
examples = read_examples_from_file(args.data_dir, mode)
features = convert_examples_to_features(
examples = self.token_classification_task.read_examples_from_file(args.data_dir, mode)
features = self.token_classification_task.convert_examples_to_features(
examples,
self.labels,
args.max_seq_length,
@@ -74,7 +86,7 @@ class NERTransformer(BaseTransformer):
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
def load_dataset(self, mode, batch_size):
def get_dataloader(self, mode: int, batch_size: int) -> DataLoader:
"Load datasets. Called after prepare data."
cached_features_file = self._feature_file(mode)
logger.info("Loading features from cached file %s", cached_features_file)
@@ -124,6 +136,7 @@ class NERTransformer(BaseTransformer):
results = {
"val_loss": val_loss_mean,
"accuracy_score": accuracy_score(out_label_list, preds_list),
"precision": precision_score(out_label_list, preds_list),
"recall": recall_score(out_label_list, preds_list),
"f1": f1_score(out_label_list, preds_list),
@@ -154,6 +167,9 @@ class NERTransformer(BaseTransformer):
def add_model_specific_args(parser, root_dir):
# Add NER specific options
BaseTransformer.add_model_specific_args(parser, root_dir)
parser.add_argument(
"--task_type", default="NER", type=str, help="Task type to fine tune in training (e.g. NER, POS, etc)"
)
parser.add_argument(
"--max_seq_length",
default=128,