From 7e7fc53da5f230db379ece739457c81b2f50f13e Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Fri, 16 Aug 2019 11:02:10 -0400 Subject: [PATCH] Fixing run_glue example with RoBERTa --- examples/run_glue.py | 2 +- examples/utils_glue.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index c0f70e0863..7fb0732e61 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -279,7 +279,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): sep_token=tokenizer.sep_token, sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet - pad_token=tokenizer.encoder[tokenizer.pad_token] if args.model_type in ['roberta'] else tokenizer.vocab[tokenizer.pad_token], + pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, ) if args.local_rank in [-1, 0]: diff --git a/examples/utils_glue.py b/examples/utils_glue.py index c955e4d0ce..e1649fa5af 100644 --- a/examples/utils_glue.py +++ b/examples/utils_glue.py @@ -425,9 +425,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length, # Account for [CLS], [SEP], [SEP] with "- 3" _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) else: - # Account for [CLS] and [SEP] with "- 2" - if len(tokens_a) > max_seq_length - 2: - tokens_a = tokens_a[:(max_seq_length - 2)] + # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. + special_tokens_count = 3 if sep_token_extra else 2 + if len(tokens_a) > max_seq_length - special_tokens_count: + tokens_a = tokens_a[:(max_seq_length - special_tokens_count)] # The convention in BERT is: # (a) For sequence pairs: