@@ -13,7 +13,7 @@ streamlit
|
|||||||
elasticsearch
|
elasticsearch
|
||||||
nltk
|
nltk
|
||||||
pandas
|
pandas
|
||||||
datasets
|
datasets >= 1.1.3
|
||||||
fire
|
fire
|
||||||
pytest
|
pytest
|
||||||
conllu
|
conllu
|
||||||
|
|||||||
@@ -15,7 +15,8 @@
|
|||||||
"""
|
"""
|
||||||
Fine-tuning the library models for token classification.
|
Fine-tuning the library models for token classification.
|
||||||
"""
|
"""
|
||||||
# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as comments.
|
# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as
|
||||||
|
# comments.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -24,7 +25,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import ClassLabel, load_dataset
|
||||||
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
|
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -198,12 +199,17 @@ def main():
|
|||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
column_names = datasets["train"].column_names
|
column_names = datasets["train"].column_names
|
||||||
|
features = datasets["train"].features
|
||||||
else:
|
else:
|
||||||
column_names = datasets["validation"].column_names
|
column_names = datasets["validation"].column_names
|
||||||
text_column_name = "words" if "words" in column_names else column_names[0]
|
features = datasets["validation"].features
|
||||||
label_column_name = data_args.task_name if data_args.task_name in column_names else column_names[1]
|
text_column_name = "tokens" if "tokens" in column_names else column_names[0]
|
||||||
|
label_column_name = (
|
||||||
|
f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
|
||||||
|
)
|
||||||
|
|
||||||
# Labeling (this part will be easier when https://github.com/huggingface/datasets/issues/797 is solved)
|
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
|
||||||
|
# unique labels.
|
||||||
def get_label_list(labels):
|
def get_label_list(labels):
|
||||||
unique_labels = set()
|
unique_labels = set()
|
||||||
for label in labels:
|
for label in labels:
|
||||||
@@ -212,8 +218,13 @@ def main():
|
|||||||
label_list.sort()
|
label_list.sort()
|
||||||
return label_list
|
return label_list
|
||||||
|
|
||||||
label_list = get_label_list(datasets["train"][label_column_name])
|
if isinstance(features[label_column_name].feature, ClassLabel):
|
||||||
label_to_id = {l: i for i, l in enumerate(label_list)}
|
label_list = features[label_column_name].feature.names
|
||||||
|
# No need to convert the labels since they are already ints.
|
||||||
|
label_to_id = {i: i for i in range(len(label_list))}
|
||||||
|
else:
|
||||||
|
label_list = get_label_list(datasets["train"][label_column_name])
|
||||||
|
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
|
|
||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
|
|||||||
Reference in New Issue
Block a user