From 39f426be6577d4534a058c9c42d52053a0ef9257 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Tue, 13 Aug 2019 15:19:50 -0400 Subject: [PATCH] Added special tokens and to RoBERTa. --- examples/run_glue.py | 2 +- pytorch_transformers/tokenization_roberta.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index 445a9a5912..c0f70e0863 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -279,7 +279,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): sep_token=tokenizer.sep_token, sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet - pad_token=1 if args.model_type in ['roberta'] else 0, # TODO(Lysandre: replace with tokenizer.pad_token when implemented) + pad_token=tokenizer.encoder[tokenizer.pad_token] if args.model_type in ['roberta'] else tokenizer.vocab[tokenizer.pad_token], pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, ) if args.local_rank in [-1, 0]: diff --git a/pytorch_transformers/tokenization_roberta.py b/pytorch_transformers/tokenization_roberta.py index 8f5cecee8a..1db8013183 100644 --- a/pytorch_transformers/tokenization_roberta.py +++ b/pytorch_transformers/tokenization_roberta.py @@ -73,9 +73,10 @@ class RobertaTokenizer(PreTrainedTokenizer): max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, vocab_file, merges_file, errors='replace', bos_token="", eos_token="", sep_token="", - cls_token="", unk_token="", **kwargs): + cls_token="", unk_token="", pad_token='', mask_token='', **kwargs): super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, - sep_token=sep_token, cls_token=cls_token, **kwargs) + sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, + mask_token=mask_token, **kwargs) self.encoder = json.load(open(vocab_file, encoding="utf-8")) self.decoder = {v: k for k, v in self.encoder.items()}