From 455a4c842c9137fdda0548e80b5cbb766643b76c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 22:20:51 +0200 Subject: [PATCH] add distilbert tokenizer --- pytorch_transformers/tokenization_auto.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_transformers/tokenization_auto.py b/pytorch_transformers/tokenization_auto.py index 576dee70ec..447d360ca2 100644 --- a/pytorch_transformers/tokenization_auto.py +++ b/pytorch_transformers/tokenization_auto.py @@ -25,6 +25,7 @@ from .tokenization_transfo_xl import TransfoXLTokenizer from .tokenization_xlnet import XLNetTokenizer from .tokenization_xlm import XLMTokenizer from .tokenization_roberta import RobertaTokenizer +from.tokenization_distilbert import DistilBertTokenizer logger = logging.getLogger(__name__) @@ -39,13 +40,14 @@ class AutoTokenizer(object): The tokenizer class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertTokenizer (DistilBert model) + - contains `roberta`: RobertaTokenizer (RoBERTa model) - contains `bert`: BertTokenizer (Bert model) - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) - contains `xlnet`: XLNetTokenizer (XLNet model) - contains `xlm`: XLMTokenizer (XLM model) - - contains `roberta`: RobertaTokenizer (RoBERTa model) This class cannot be instantiated using `__init__()` (throw an error). """ @@ -60,13 +62,14 @@ class AutoTokenizer(object): The tokenizer class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertTokenizer (DistilBert model) + - contains `roberta`: RobertaTokenizer (XLM model) - contains `bert`: BertTokenizer (Bert model) - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) - contains `xlnet`: XLNetTokenizer (XLNet model) - contains `xlm`: XLMTokenizer (XLM model) - - contains `roberta`: RobertaTokenizer (XLM model) Params: pretrained_model_name_or_path: either: @@ -95,6 +98,8 @@ class AutoTokenizer(object): config = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')` """ + if 'distilbert' in pretrained_model_name_or_path: + return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) if 'roberta' in pretrained_model_name_or_path: return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) elif 'bert' in pretrained_model_name_or_path: