Add POS tagging and Phrase chunking token classification examples (#6457)

* Add more token classification examples

* POS tagging example

* Phrase chunking example

* PR review fixes

* Add conllu to third party list (used in token classification examples)
This commit is contained in:
vblagoje
2020-08-13 12:09:51 -04:00
committed by GitHub
parent f51161e230
commit eda07efaa5
10 changed files with 473 additions and 204 deletions

View File

@@ -14,16 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning the library models for named entity recognition on CoNLL-2003. """
import logging
import os
import sys
from dataclasses import dataclass, field
from importlib import import_module
from typing import Dict, List, Optional, Tuple
import numpy as np
from seqeval.metrics import f1_score, precision_score, recall_score
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch import nn
from transformers import (
@@ -36,7 +35,7 @@ from transformers import (
TrainingArguments,
set_seed,
)
from utils_ner import NerDataset, Split, get_labels
from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask
logger = logging.getLogger(__name__)
@@ -54,6 +53,9 @@ class ModelArguments:
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
task_type: Optional[str] = field(
default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
@@ -113,6 +115,16 @@ def main():
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
module = import_module("tasks")
try:
token_classification_task_clazz = getattr(module, model_args.task_type)
token_classification_task: TokenClassificationTask = token_classification_task_clazz()
except AttributeError:
raise ValueError(
f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -133,7 +145,7 @@ def main():
set_seed(training_args.seed)
# Prepare CONLL-2003 task
labels = get_labels(data_args.labels)
labels = token_classification_task.get_labels(data_args.labels)
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
num_labels = len(labels)
@@ -164,7 +176,8 @@ def main():
# Get datasets
train_dataset = (
NerDataset(
TokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
@@ -177,7 +190,8 @@ def main():
else None
)
eval_dataset = (
NerDataset(
TokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
@@ -209,6 +223,7 @@ def main():
def compute_metrics(p: EvalPrediction) -> Dict:
preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
return {
"accuracy_score": accuracy_score(out_label_list, preds_list),
"precision": precision_score(out_label_list, preds_list),
"recall": recall_score(out_label_list, preds_list),
"f1": f1_score(out_label_list, preds_list),
@@ -253,7 +268,8 @@ def main():
# Predict
if training_args.do_predict:
test_dataset = NerDataset(
test_dataset = TokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
@@ -278,19 +294,7 @@ def main():
if trainer.is_world_master():
with open(output_test_predictions_file, "w") as writer:
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not preds_list[example_id]:
example_id += 1
elif preds_list[example_id]:
output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logger.warning(
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
)
token_classification_task.write_predictions_to_file(writer, f, preds_list)
return results