Add option to predict on test set

This commit is contained in:
Marianne Stecklina
2019-09-23 10:51:54 +02:00
committed by thomwolf
parent 7f5367e0b1
commit 5ff9cd158a
2 changed files with 46 additions and 19 deletions

View File

@@ -51,13 +51,8 @@ class InputFeatures(object):
self.label_ids = label_ids
def read_examples_from_file(data_dir, evaluate=False):
if evaluate:
file_path = os.path.join(data_dir, "dev.txt")
guid_prefix = "dev"
else:
file_path = os.path.join(data_dir, "train.txt")
guid_prefix = "train"
def read_examples_from_file(data_dir, mode):
file_path = os.path.join(data_dir, "{}.txt".format(mode))
guid_index = 1
examples = []
with open(file_path, encoding="utf-8") as f:
@@ -66,7 +61,7 @@ def read_examples_from_file(data_dir, evaluate=False):
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
if words:
examples.append(InputExample(guid="{}-{}".format(guid_prefix, guid_index),
examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
words=words,
labels=labels))
guid_index += 1
@@ -75,9 +70,13 @@ def read_examples_from_file(data_dir, evaluate=False):
else:
splits = line.split(" ")
words.append(splits[0])
labels.append(splits[-1].replace("\n", ""))
if len(splits) > 1:
labels.append(splits[-1].replace("\n", ""))
else:
# Examples could have no label for mode = "test"
labels.append("O")
if words:
examples.append(InputExample(guid="%s-%d".format(guid_prefix, guid_index),
examples.append(InputExample(guid="%s-%d".format(mode, guid_index),
words=words,
labels=labels))
return examples