From 3a5d1ea2a51964480a7b797e22f1d0f8053fba6b Mon Sep 17 00:00:00 2001 From: Zhangyx Date: Fri, 29 May 2020 23:12:24 +0800 Subject: [PATCH] Fix two bugs: 1. Index of test data of SST-2. 2. Label index of MNLI data. (#4546) --- src/transformers/data/datasets/glue.py | 20 +++++++++++--------- src/transformers/data/processors/glue.py | 3 ++- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/transformers/data/datasets/glue.py b/src/transformers/data/datasets/glue.py index eaaa40f628..2ee260ea9e 100644 --- a/src/transformers/data/datasets/glue.py +++ b/src/transformers/data/datasets/glue.py @@ -86,6 +86,15 @@ class GlueDataset(Dataset): mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name, ), ) + label_list = self.processor.get_labels() + if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in ( + RobertaTokenizer, + RobertaTokenizerFast, + XLMRobertaTokenizer, + ): + # HACK(label indices are swapped in RoBERTa pretrained model) + label_list[1], label_list[2] = label_list[2], label_list[1] + self.label_list = label_list # Make sure only the first process in distributed training processes the dataset, # and the others will use the cache. @@ -100,14 +109,7 @@ class GlueDataset(Dataset): ) else: logger.info(f"Creating features from dataset file at {args.data_dir}") - label_list = self.processor.get_labels() - if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in ( - RobertaTokenizer, - RobertaTokenizerFast, - XLMRobertaTokenizer, - ): - # HACK(label indices are swapped in RoBERTa pretrained model) - label_list[1], label_list[2] = label_list[2], label_list[1] + if mode == Split.dev: examples = self.processor.get_dev_examples(args.data_dir) elif mode == Split.test: @@ -137,4 +139,4 @@ class GlueDataset(Dataset): return self.features[i] def get_labels(self): - return self.processor.get_labels() + return self.label_list diff --git a/src/transformers/data/processors/glue.py b/src/transformers/data/processors/glue.py index ecc43f4da4..870817a60e 100644 --- a/src/transformers/data/processors/glue.py +++ b/src/transformers/data/processors/glue.py @@ -332,11 +332,12 @@ class Sst2Processor(DataProcessor): def _create_examples(self, lines, set_type): """Creates examples for the training, dev and test sets.""" examples = [] + text_index = 1 if set_type == "test" else 0 for (i, line) in enumerate(lines): if i == 0: continue guid = "%s-%s" % (set_type, i) - text_a = line[0] + text_a = line[text_index] label = None if set_type == "test" else line[1] examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) return examples