[HANS] Fix label_list for RoBERTa/BART (class flipping) (#5196)

* fix weirdness in roberta/bart for mnli trained checkpoints

* black compliance

* isort code check
This commit is contained in:
Victor SANH
2020-06-24 14:38:15 -04:00
committed by GitHub
parent fc24a93e64
commit 4965aee064
2 changed files with 41 additions and 6 deletions

View File

@@ -33,7 +33,7 @@ from transformers import (
default_data_collator,
set_seed,
)
from utils_hans import HansDataset, InputFeatures, hans_processors
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
logger = logging.getLogger(__name__)
@@ -130,9 +130,7 @@ def main():
set_seed(training_args.seed)
try:
processor = hans_processors[data_args.task_name]()
label_list = processor.get_labels()
num_labels = len(label_list)
num_labels = hans_tasks_num_labels[data_args.task_name]
except KeyError:
raise ValueError("Task not found: %s" % (data_args.task_name))
@@ -214,6 +212,7 @@ def main():
pair_ids = [ex.pairID for ex in eval_dataset]
output_eval_file = os.path.join(training_args.output_dir, "hans_predictions.txt")
label_list = eval_dataset.get_labels()
if trainer.is_world_master():
with open(output_eval_file, "w") as writer:
writer.write("pairID,gold_label\n")