Merge pull request #347 from jplehmann/feature/sst2-processor
Processor for SST-2 task
This commit is contained in:
@@ -196,6 +196,37 @@ class ColaProcessor(DataProcessor):
|
|||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
class Sst2Processor(DataProcessor):
    """Data processor for the SST-2 corpus as distributed with GLUE."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(train_path), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(dev_path), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Turns parsed rows into single-sentence `InputExample` objects.

        Args:
            lines: Sequence of rows, each indexable so that column 0 holds
                the sentence and column 1 holds the label.
            set_type: Split name (e.g. "train" or "dev"); used as the guid
                prefix.

        Returns:
            A list with one `InputExample` per row after the first. Each
            guid is "<set_type>-<row index>" and `text_b` is always None.
        """
        # Row 0 is skipped unconditionally (presumably the TSV header --
        # confirm against the data files); guids keep the original row index.
        return [
            InputExample(
                guid="%s-%s" % (set_type, row_idx),
                text_a=row[0],
                text_b=None,
                label=row[1],
            )
            for row_idx, row in enumerate(lines)
            if row_idx != 0
        ]
|
||||||
|
|
||||||
|
|
||||||
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
||||||
"""Loads a data file into a list of `InputBatch`s."""
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
@@ -401,10 +432,12 @@ def main():
|
|||||||
"cola": ColaProcessor,
|
"cola": ColaProcessor,
|
||||||
"mnli": MnliProcessor,
|
"mnli": MnliProcessor,
|
||||||
"mrpc": MrpcProcessor,
|
"mrpc": MrpcProcessor,
|
||||||
|
"sst-2": Sst2Processor,
|
||||||
}
|
}
|
||||||
|
|
||||||
num_labels_task = {
|
num_labels_task = {
|
||||||
"cola": 2,
|
"cola": 2,
|
||||||
|
"sst-2": 2,
|
||||||
"mnli": 3,
|
"mnli": 3,
|
||||||
"mrpc": 2,
|
"mrpc": 2,
|
||||||
}
|
}
|
||||||
@@ -597,7 +630,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
eval_loss, eval_accuracy = 0, 0
|
eval_loss, eval_accuracy = 0, 0
|
||||||
nb_eval_steps, nb_eval_examples = 0, 0
|
nb_eval_steps, nb_eval_examples = 0, 0
|
||||||
|
|
||||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.to(device)
|
input_mask = input_mask.to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user