Fix two bugs: 1. Wrong column index used for the text of SST-2 test data. 2. Swapped MNLI label indices were not reflected in the labels returned by `get_labels()`. (#4546)
This commit is contained in:
@@ -86,6 +86,15 @@ class GlueDataset(Dataset):
|
|||||||
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
|
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
label_list = self.processor.get_labels()
|
||||||
|
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
|
||||||
|
RobertaTokenizer,
|
||||||
|
RobertaTokenizerFast,
|
||||||
|
XLMRobertaTokenizer,
|
||||||
|
):
|
||||||
|
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||||
|
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||||
|
self.label_list = label_list
|
||||||
|
|
||||||
# Make sure only the first process in distributed training processes the dataset,
|
# Make sure only the first process in distributed training processes the dataset,
|
||||||
# and the others will use the cache.
|
# and the others will use the cache.
|
||||||
@@ -100,14 +109,7 @@ class GlueDataset(Dataset):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Creating features from dataset file at {args.data_dir}")
|
logger.info(f"Creating features from dataset file at {args.data_dir}")
|
||||||
label_list = self.processor.get_labels()
|
|
||||||
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
|
|
||||||
RobertaTokenizer,
|
|
||||||
RobertaTokenizerFast,
|
|
||||||
XLMRobertaTokenizer,
|
|
||||||
):
|
|
||||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
|
||||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
|
||||||
if mode == Split.dev:
|
if mode == Split.dev:
|
||||||
examples = self.processor.get_dev_examples(args.data_dir)
|
examples = self.processor.get_dev_examples(args.data_dir)
|
||||||
elif mode == Split.test:
|
elif mode == Split.test:
|
||||||
@@ -137,4 +139,4 @@ class GlueDataset(Dataset):
|
|||||||
return self.features[i]
|
return self.features[i]
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
return self.processor.get_labels()
|
return self.label_list
|
||||||
|
|||||||
@@ -332,11 +332,12 @@ class Sst2Processor(DataProcessor):
|
|||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training, dev and test sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
examples = []
|
examples = []
|
||||||
|
text_index = 1 if set_type == "test" else 0
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
continue
|
continue
|
||||||
guid = "%s-%s" % (set_type, i)
|
guid = "%s-%s" % (set_type, i)
|
||||||
text_a = line[0]
|
text_a = line[text_index]
|
||||||
label = None if set_type == "test" else line[1]
|
label = None if set_type == "test" else line[1]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|||||||
Reference in New Issue
Block a user