From 51397336234a56ed2169413385c097fa1db4532d Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Tue, 26 May 2020 01:33:55 +0530
Subject: [PATCH] LongformerTokenizerFast (#4547)

---
 src/transformers/__init__.py                |  2 +-
 src/transformers/tokenization_longformer.py | 11 ++++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 3c034d1768..35cdd5dde1 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -139,7 +139,7 @@ from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFas
 from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
 from .tokenization_flaubert import FlaubertTokenizer
 from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
-from .tokenization_longformer import LongformerTokenizer
+from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
 from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
 from .tokenization_reformer import ReformerTokenizer
 from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
diff --git a/src/transformers/tokenization_longformer.py b/src/transformers/tokenization_longformer.py
index 5d597935e2..7ac2a00901 100644
--- a/src/transformers/tokenization_longformer.py
+++ b/src/transformers/tokenization_longformer.py
@@ -15,7 +15,7 @@
 
 import logging
 
-from .tokenization_roberta import RobertaTokenizer
+from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
 
 
 logger = logging.getLogger(__name__)
@@ -40,3 +40,12 @@ class LongformerTokenizer(RobertaTokenizer):
         "vocab_file": {m: vocab_url for m in _all_longformer_models},
         "merges_file": {m: merges_url for m in _all_longformer_models},
     }
+
+
+class LongformerTokenizerFast(RobertaTokenizerFast):
+    # merges and vocab same as Roberta
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    pretrained_vocab_files_map = {
+        "vocab_file": {m: vocab_url for m in _all_longformer_models},
+        "merges_file": {m: merges_url for m in _all_longformer_models},
+    }