Fixing run_glue example with RoBERTa

This commit is contained in:
LysandreJik
2019-08-16 11:02:10 -04:00
committed by Lysandre Debut
parent ab05280666
commit 7e7fc53da5
2 changed files with 5 additions and 4 deletions

View File

@@ -279,7 +279,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
sep_token=tokenizer.sep_token,
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
pad_token=tokenizer.encoder[tokenizer.pad_token] if args.model_type in ['roberta'] else tokenizer.vocab[tokenizer.pad_token],
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
)
if args.local_rank in [-1, 0]:

View File

@@ -425,9 +425,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:(max_seq_length - 2)]
# Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
special_tokens_count = 3 if sep_token_extra else 2
if len(tokens_a) > max_seq_length - special_tokens_count:
tokens_a = tokens_a[:(max_seq_length - special_tokens_count)]
# The convention in BERT is:
# (a) For sequence pairs: