From d8923270e6c497862f990a3c72e40cc1ddd01d4e Mon Sep 17 00:00:00 2001 From: Jason Phang Date: Fri, 16 Aug 2019 15:58:19 -0400 Subject: [PATCH] Correct truncation for RoBERTa in 2-input GLUE --- examples/utils_glue.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/utils_glue.py b/examples/utils_glue.py index e1649fa5af..3e3f104672 100644 --- a/examples/utils_glue.py +++ b/examples/utils_glue.py @@ -422,8 +422,9 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokens_b = tokenizer.tokenize(example.text_b) # Modifies `tokens_a` and `tokens_b` in place so that the total # length is less than the specified length. - # Account for [CLS], [SEP], [SEP] with "- 3" - _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) + # Account for [CLS], [SEP], [SEP] with "- 3" and with "- 4" for RoBERTa. + special_tokens_count = 4 if sep_token_extra else 3 + _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count) else: # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. special_tokens_count = 3 if sep_token_extra else 2