From e24e19ce3bbbc3fe317e4d277b919cd1cb31fc47 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 15 Aug 2019 14:02:11 -0400 Subject: [PATCH] Added RoBERTa to AutoModel/AutoConfig --- pytorch_transformers/modeling_auto.py | 33 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/pytorch_transformers/modeling_auto.py b/pytorch_transformers/modeling_auto.py index 64b151e3a3..47c37a57d6 100644 --- a/pytorch_transformers/modeling_auto.py +++ b/pytorch_transformers/modeling_auto.py @@ -29,6 +29,7 @@ from .modeling_gpt2 import GPT2Config, GPT2Model from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel from .modeling_xlnet import XLNetConfig, XLNetModel from .modeling_xlm import XLMConfig, XLMModel +from .modeling_roberta import RobertaConfig, RobertaModel from .modeling_utils import PreTrainedModel, SequenceSummary @@ -51,6 +52,7 @@ class AutoConfig(object): - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) - contains `xlnet`: XLNetConfig (XLNet model) - contains `xlm`: XLMConfig (XLM model) + - contains `roberta`: RobertaConfig (RoBERTa model) This class cannot be instantiated using `__init__()` (throw an error). """ @@ -71,6 +73,7 @@ class AutoConfig(object): - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) - contains `xlnet`: XLNetConfig (XLNet model) - contains `xlm`: XLMConfig (XLM model) + - contains `roberta`: RobertaConfig (RoBERTa model) Params: **pretrained_model_name_or_path**: either: @@ -119,6 +122,8 @@ class AutoConfig(object): return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) elif 'xlm' in pretrained_model_name_or_path: return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'roberta' in pretrained_model_name_or_path: + return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " @@ -137,12 +142,13 @@ class AutoModel(object): The base model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): - - contains `bert`: BertConfig (Bert model) - - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) - - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) - - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) - - contains `xlnet`: XLNetConfig (XLNet model) - - contains `xlm`: XLMConfig (XLM model) + - contains `bert`: BertModel (Bert model) + - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) + - contains `gpt2`: GPT2Model (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLModel (Transformer-XL model) + - contains `xlnet`: XLNetModel (XLNet model) + - contains `xlm`: XLMModel (XLM model) + - contains `roberta`: RobertaModel (RoBERTa model) This class cannot be instantiated using `__init__()` (throw an error). """ @@ -157,12 +163,13 @@ class AutoModel(object): The base model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): - - contains `bert`: BertConfig (Bert model) - - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) - - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) - - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) - - contains `xlnet`: XLNetConfig (XLNet model) - - contains `xlm`: XLMConfig (XLM model) + - contains `bert`: BertModel (Bert model) + - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) + - contains `gpt2`: GPT2Model (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLModel (Transformer-XL model) + - contains `xlnet`: XLNetModel (XLNet model) + - contains `xlm`: XLMModel (XLM model) + - contains `roberta`: RobertaModel (RoBERTa model) The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) To train the model, you should first set it back in training mode with `model.train()` @@ -230,6 +237,8 @@ class AutoModel(object): return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'xlm' in pretrained_model_name_or_path: return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'roberta' in pretrained_model_name_or_path: + return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "