[HANS] Fix label_list for RoBERTa/BART (class flipping) (#5196)
* Fix weirdness in RoBERTa/BART for MNLI-trained checkpoints
* black compliance
* isort code check
This commit is contained in:
@@ -33,7 +33,7 @@ from transformers import (
|
|||||||
default_data_collator,
|
default_data_collator,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from utils_hans import HansDataset, InputFeatures, hans_processors
|
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -130,9 +130,7 @@ def main():
|
|||||||
set_seed(training_args.seed)
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
processor = hans_processors[data_args.task_name]()
|
num_labels = hans_tasks_num_labels[data_args.task_name]
|
||||||
label_list = processor.get_labels()
|
|
||||||
num_labels = len(label_list)
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise ValueError("Task not found: %s" % (data_args.task_name))
|
raise ValueError("Task not found: %s" % (data_args.task_name))
|
||||||
|
|
||||||
@@ -214,6 +212,7 @@ def main():
|
|||||||
|
|
||||||
pair_ids = [ex.pairID for ex in eval_dataset]
|
pair_ids = [ex.pairID for ex in eval_dataset]
|
||||||
output_eval_file = os.path.join(training_args.output_dir, "hans_predictions.txt")
|
output_eval_file = os.path.join(training_args.output_dir, "hans_predictions.txt")
|
||||||
|
label_list = eval_dataset.get_labels()
|
||||||
if trainer.is_world_master():
|
if trainer.is_world_master():
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
writer.write("pairID,gold_label\n")
|
writer.write("pairID,gold_label\n")
|
||||||
|
|||||||
@@ -22,7 +22,17 @@ from typing import List, Optional, Union
|
|||||||
import tqdm
|
import tqdm
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
from transformers import DataProcessor, PreTrainedTokenizer, is_tf_available, is_torch_available
|
from transformers import (
|
||||||
|
BartTokenizer,
|
||||||
|
BartTokenizerFast,
|
||||||
|
DataProcessor,
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
RobertaTokenizer,
|
||||||
|
RobertaTokenizerFast,
|
||||||
|
XLMRobertaTokenizer,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -105,6 +115,17 @@ if is_torch_available():
|
|||||||
"dev" if evaluate else "train", tokenizer.__class__.__name__, str(max_seq_length), task,
|
"dev" if evaluate else "train", tokenizer.__class__.__name__, str(max_seq_length), task,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
label_list = processor.get_labels()
|
||||||
|
if tokenizer.__class__ in (
|
||||||
|
RobertaTokenizer,
|
||||||
|
RobertaTokenizerFast,
|
||||||
|
XLMRobertaTokenizer,
|
||||||
|
BartTokenizer,
|
||||||
|
BartTokenizerFast,
|
||||||
|
):
|
||||||
|
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||||
|
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||||
|
self.label_list = label_list
|
||||||
|
|
||||||
# Make sure only the first process in distributed training processes the dataset,
|
# Make sure only the first process in distributed training processes the dataset,
|
||||||
# and the others will use the cache.
|
# and the others will use the cache.
|
||||||
@@ -116,7 +137,6 @@ if is_torch_available():
|
|||||||
self.features = torch.load(cached_features_file)
|
self.features = torch.load(cached_features_file)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||||
label_list = processor.get_labels()
|
|
||||||
|
|
||||||
examples = (
|
examples = (
|
||||||
processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
||||||
@@ -133,6 +153,9 @@ if is_torch_available():
|
|||||||
def __getitem__(self, i) -> InputFeatures:
|
def __getitem__(self, i) -> InputFeatures:
|
||||||
return self.features[i]
|
return self.features[i]
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
return self.label_list
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -156,6 +179,16 @@ if is_tf_available():
|
|||||||
):
|
):
|
||||||
processor = hans_processors[task]()
|
processor = hans_processors[task]()
|
||||||
label_list = processor.get_labels()
|
label_list = processor.get_labels()
|
||||||
|
if tokenizer.__class__ in (
|
||||||
|
RobertaTokenizer,
|
||||||
|
RobertaTokenizerFast,
|
||||||
|
XLMRobertaTokenizer,
|
||||||
|
BartTokenizer,
|
||||||
|
BartTokenizerFast,
|
||||||
|
):
|
||||||
|
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||||
|
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||||
|
self.label_list = label_list
|
||||||
|
|
||||||
examples = processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
examples = processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
||||||
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
||||||
@@ -206,6 +239,9 @@ if is_tf_available():
|
|||||||
def __getitem__(self, i) -> InputFeatures:
|
def __getitem__(self, i) -> InputFeatures:
|
||||||
return self.features[i]
|
return self.features[i]
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
return self.label_list
|
||||||
|
|
||||||
|
|
||||||
class HansProcessor(DataProcessor):
|
class HansProcessor(DataProcessor):
|
||||||
"""Processor for the HANS data set."""
|
"""Processor for the HANS data set."""
|
||||||
|
|||||||
Reference in New Issue
Block a user