Implement fine-tuning BERT on CoNLL-2003 named entity recognition task
This commit is contained in:
committed by
thomwolf
parent
5adb39e757
commit
383ef96747
@@ -51,8 +51,13 @@ class InputFeatures(object):
|
||||
self.label_ids = label_ids
|
||||
|
||||
|
||||
def read_examples_from_file(data_dir, mode):
|
||||
file_path = os.path.join(data_dir, "{}.txt".format(mode))
|
||||
def read_examples_from_file(data_dir, evaluate=False):
|
||||
if evaluate:
|
||||
file_path = os.path.join(data_dir, "dev.txt")
|
||||
guid_prefix = "dev"
|
||||
else:
|
||||
file_path = os.path.join(data_dir, "train.txt")
|
||||
guid_prefix = "train"
|
||||
guid_index = 1
|
||||
examples = []
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
@@ -61,7 +66,7 @@ def read_examples_from_file(data_dir, mode):
|
||||
for line in f:
|
||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||
if words:
|
||||
examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
|
||||
examples.append(InputExample(guid="{}-{}".format(guid_prefix, guid_index),
|
||||
words=words,
|
||||
labels=labels))
|
||||
guid_index += 1
|
||||
@@ -70,13 +75,9 @@ def read_examples_from_file(data_dir, mode):
|
||||
else:
|
||||
splits = line.split(" ")
|
||||
words.append(splits[0])
|
||||
if len(splits) > 1:
|
||||
labels.append(splits[-1].replace("\n", ""))
|
||||
else:
|
||||
# Examples could have no label for mode = "test"
|
||||
labels.append("O")
|
||||
labels.append(splits[-1][:-1])
|
||||
if words:
|
||||
examples.append(InputExample(guid="%s-%d".format(mode, guid_index),
|
||||
examples.append(InputExample(guid="%s-%d".format(guid_prefix, guid_index),
|
||||
words=words,
|
||||
labels=labels))
|
||||
return examples
|
||||
@@ -201,12 +202,5 @@ def convert_examples_to_features(examples,
|
||||
return features
|
||||
|
||||
|
||||
def get_labels(path):
|
||||
if path:
|
||||
with open(path, "r") as f:
|
||||
labels = f.read().splitlines()
|
||||
if "O" not in labels:
|
||||
labels = ["O"] + labels
|
||||
return labels
|
||||
else:
|
||||
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
|
||||
def get_labels():
|
||||
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
|
||||
|
||||
Reference in New Issue
Block a user