From baf08ca1d4ab5aee1d530fc1801370e8a81cc091 Mon Sep 17 00:00:00 2001
From: Julien Chaumond
Date: Tue, 13 Aug 2019 12:51:15 -0400
Subject: [PATCH] [RoBERTa] run_glue: correct pad_token + reorder labels

---
 examples/run_glue.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/run_glue.py b/examples/run_glue.py
index f6cd73ed0b..445a9a5912 100644
--- a/examples/run_glue.py
+++ b/examples/run_glue.py
@@ -268,6 +268,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
     else:
         logger.info("Creating features from dataset file at %s", args.data_dir)
         label_list = processor.get_labels()
+        if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
+            # HACK(label indices are swapped in RoBERTa pretrained model)
+            label_list[1], label_list[2] = label_list[2], label_list[1]
         examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
         features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
             cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
@@ -276,7 +279,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
             sep_token=tokenizer.sep_token,
             sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
             pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
-            pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
+            pad_token=1 if args.model_type in ['roberta'] else 0, # TODO(Lysandre: replace with tokenizer.pad_token when implemented)
+            pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
+        )
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)