From dec8f4d6fdc106ec63a30ad9ad33526be0675f5a Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Fri, 30 Aug 2019 13:52:18 -0400 Subject: [PATCH] Added DistilBERT models to all other AutoModels. --- pytorch_transformers/modeling_auto.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/pytorch_transformers/modeling_auto.py b/pytorch_transformers/modeling_auto.py index b15a21c646..0c328909c2 100644 --- a/pytorch_transformers/modeling_auto.py +++ b/pytorch_transformers/modeling_auto.py @@ -30,12 +30,13 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, TransfoXLLMHea from .modeling_xlnet import XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering from .modeling_xlm import XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering from .modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification -from .modeling_distilbert import DistilBertConfig, DistilBertModel +from .modeling_distilbert import DistilBertConfig, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification from .modeling_utils import PreTrainedModel, SequenceSummary, add_start_docstrings logger = logging.getLogger(__name__) + class AutoConfig(object): r""":class:`~pytorch_transformers.AutoConfig` is a generic configuration class that will be instantiated as one of the configuration classes of the library @@ -47,6 +48,7 @@ class AutoConfig(object): The base model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertConfig (DistilBERT model) - contains `bert`: BertConfig (Bert model) - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) @@ -68,6 +70,7 @@ class AutoConfig(object): The configuration class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertConfig (DistilBERT model) - contains `bert`: BertConfig (Bert model) - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) @@ -151,6 +154,7 @@ class AutoModel(object): The base model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertModel (DistilBERT model) - contains `roberta`: RobertaModel (RoBERTa model) - contains `bert`: BertModel (Bert model) - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) @@ -172,6 +176,7 @@ class AutoModel(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertModel (DistilBERT model) - contains `roberta`: RobertaModel (RoBERTa model) - contains `bert`: BertModel (Bert model) - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) @@ -258,7 +263,6 @@ class AutoModel(object): "'xlm', 'roberta'".format(pretrained_model_name_or_path)) - class AutoModelWithLMHead(object): r""" :class:`~pytorch_transformers.AutoModelWithLMHead` is a generic model class @@ -271,6 +275,7 @@ class AutoModelWithLMHead(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForMaskedLM (DistilBERT model) - contains `roberta`: RobertaForMaskedLM (RoBERTa model) - contains `bert`: BertForMaskedLM (Bert model) - contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model) @@ -295,6 +300,7 @@ class AutoModelWithLMHead(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForMaskedLM (DistilBERT model) - contains `roberta`: RobertaForMaskedLM (RoBERTa model) - contains `bert`: BertForMaskedLM (Bert model) - contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model) @@ -359,7 +365,9 @@ class AutoModelWithLMHead(object): model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if 'roberta' in pretrained_model_name_or_path: + if 'distilbert' in pretrained_model_name_or_path: + return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'roberta' in pretrained_model_name_or_path: return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'bert' in pretrained_model_name_or_path: return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @@ -391,6 +399,7 @@ class AutoModelForSequenceClassification(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model) - contains `roberta`: RobertaForSequenceClassification (RoBERTa model) - contains `bert`: BertForSequenceClassification (Bert model) - contains `xlnet`: XLNetForSequenceClassification (XLNet model) @@ -412,6 +421,7 @@ class AutoModelForSequenceClassification(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model) - contains `roberta`: RobertaForSequenceClassification (RoBERTa model) - contains `bert`: BertForSequenceClassification (Bert model) - contains `xlnet`: XLNetForSequenceClassification (XLNet model) @@ -473,7 +483,9 @@ class AutoModelForSequenceClassification(object): model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if 'roberta' in pretrained_model_name_or_path: + if 'distilbert' in pretrained_model_name_or_path: + return DistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'roberta' in pretrained_model_name_or_path: return RobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'bert' in pretrained_model_name_or_path: return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @@ -498,6 +510,7 @@ class AutoModelForQuestionAnswering(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model) - contains `bert`: BertForQuestionAnswering (Bert model) - contains `xlnet`: XLNetForQuestionAnswering (XLNet model) - contains `xlm`: XLMForQuestionAnswering (XLM model) @@ -518,6 +531,7 @@ class AutoModelForQuestionAnswering(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model) - contains `bert`: BertForQuestionAnswering (Bert model) - contains `xlnet`: XLNetForQuestionAnswering (XLNet model) - contains `xlm`: XLMForQuestionAnswering (XLM model) @@ -578,7 +592,9 @@ class AutoModelForQuestionAnswering(object): model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if 'bert' in pretrained_model_name_or_path: + if 'distilbert' in pretrained_model_name_or_path: + return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'bert' in pretrained_model_name_or_path: return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'xlnet' in pretrained_model_name_or_path: return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)