From 83dba0b67bd8d142e830eab7aa6538b4dc50e1ef Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 15 Aug 2019 17:07:07 -0400 Subject: [PATCH] Added RoBERTa tokenizer to AutoTokenizer --- pytorch_transformers/modeling_auto.py | 4 ++-- pytorch_transformers/tokenization_auto.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_transformers/modeling_auto.py b/pytorch_transformers/modeling_auto.py index 47c37a57d6..7c96b7a287 100644 --- a/pytorch_transformers/modeling_auto.py +++ b/pytorch_transformers/modeling_auto.py @@ -127,7 +127,7 @@ class AutoConfig(object): raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm'".format(pretrained_model_name_or_path)) + "'xlm', 'roberta'".format(pretrained_model_name_or_path)) class AutoModel(object): @@ -242,4 +242,4 @@ class AutoModel(object): raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm'".format(pretrained_model_name_or_path)) + "'xlm', 'roberta'".format(pretrained_model_name_or_path)) diff --git a/pytorch_transformers/tokenization_auto.py b/pytorch_transformers/tokenization_auto.py index acbe1cebc6..adb8f87cd7 100644 --- a/pytorch_transformers/tokenization_auto.py +++ b/pytorch_transformers/tokenization_auto.py @@ -24,6 +24,7 @@ from .tokenization_gpt2 import GPT2Tokenizer from .tokenization_transfo_xl import TransfoXLTokenizer from .tokenization_xlnet import XLNetTokenizer from .tokenization_xlm import XLMTokenizer +from .tokenization_roberta import RobertaTokenizer logger = logging.getLogger(__name__) @@ -44,6 +45,7 @@ class AutoTokenizer(object): - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) - contains `xlnet`: XLNetTokenizer (XLNet model) - contains `xlm`: XLMTokenizer (XLM model) + - contains `roberta`: RobertaTokenizer (RoBERTa model) This class cannot be instantiated using 
`__init__()` (throw an error). """ @@ -64,6 +66,7 @@ class AutoTokenizer(object): - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) - contains `xlnet`: XLNetTokenizer (XLNet model) - contains `xlm`: XLMTokenizer (XLM model) + - contains `roberta`: RobertaTokenizer (RoBERTa model) Params: **pretrained_model_name_or_path**: either: @@ -94,7 +97,9 @@ class AutoTokenizer(object): return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) elif 'xlm' in pretrained_model_name_or_path: return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + elif 'roberta' in pretrained_model_name_or_path: + return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm'".format(pretrained_model_name_or_path)) + "'xlm', 'roberta'".format(pretrained_model_name_or_path))