From b7141a1bc604b8f9512f89d8dc3ec9dcc062e895 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 14 Oct 2019 12:14:08 +0200 Subject: [PATCH] maxi simplication --- transformers/modeling_seq2seq.py | 75 ++------------------------------ 1 file changed, 3 insertions(+), 72 deletions(-) diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index e8106f47f5..12792c6e7a 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -21,14 +21,7 @@ import logging import torch from torch import nn -from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering -from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel -from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel -from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel -from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering -from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering -from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification -from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification +from .modeling_auto import AutoModel, AutoModelWithLMHead from .modeling_utils import PreTrainedModel, SequenceSummary @@ -43,22 +36,6 @@ class PreTrainedSeq2seq(nn.Module): that will be instantiated as a Seq2seq model with one of the base model classes of the library as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class method. - - The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. - - The base model class to instantiate is selected as the first pattern matching - in the `pretrained_model_name_or_path` string (in the following order): - - contains `distilbert`: DistilBertModel (DistilBERT model) - - contains `roberta`: RobertaModel (RoBERTa model) - - contains `bert`: BertModel (Bert model) - - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) - - contains `gpt2`: GPT2Model (OpenAI GPT-2 model) - - contains `transfo-xl`: TransfoXLModel (Transformer-XL model) - - contains `xlnet`: XLNetModel (XLNet model) - - contains `xlm`: XLMModel (XLM model) - - This class cannot be instantiated using `__init__()` (throws an error). """ def __init__(self, encoder, decoder): super(PreTrainedSeq2seq, self).__init__() @@ -69,18 +46,6 @@ class PreTrainedSeq2seq(nn.Module): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): r""" Instantiates one of the base model classes of the library from a pre-trained model configuration. - - The model class to instantiate is selected as the first pattern matching - in the `pretrained_model_name_or_path` string (in the following order): - - contains `distilbert`: DistilBertModel (DistilBERT model) - - contains `roberta`: RobertaModel (RoBERTa model) - - contains `bert`: BertModel (Bert model) - - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) - - contains `gpt2`: GPT2Model (OpenAI GPT-2 model) - - contains `transfo-xl`: TransfoXLModel (Transformer-XL model) - - contains `xlnet`: XLNetModel (XLNet model) - - contains `xlm`: XLMModel (XLM model) - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) To train the model, you should first set it back in training mode with `model.train()` @@ -155,26 +120,7 @@ class PreTrainedSeq2seq(nn.Module): else: # Load and initialize the encoder kwargs['is_decoder'] = False # Make sure the encoder will be an encoder - if 'distilbert' in pretrained_model_name_or_path: - encoder = DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif 'roberta' in pretrained_model_name_or_path: - encoder = RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif 'bert' in pretrained_model_name_or_path: - encoder = BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif 'openai-gpt' in pretrained_model_name_or_path: - encoder = OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif 'gpt2' in pretrained_model_name_or_path: - encoder = GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif 'transfo-xl' in pretrained_model_name_or_path: - encoder = TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif 'xlnet' in pretrained_model_name_or_path: - encoder = XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif 'xlm' in pretrained_model_name_or_path: - encoder = XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - else: - raise ValueError("Unrecognized model identifier in {}. Should contains one of " - "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm', 'roberta'".format(pretrained_model_name_or_path)) + encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) # Load and initialize the decoder if decoder_model: @@ -182,22 +128,7 @@ class PreTrainedSeq2seq(nn.Module): else: kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc... kwargs['is_decoder'] = True # Make sure the decoder will be an decoder - if 'distilbert' in decoder_pretrained_model_name_or_path: - decoder = DistilBertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) - elif 'roberta' in decoder_pretrained_model_name_or_path: - decoder = RobertaModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) - elif 'bert' in decoder_pretrained_model_name_or_path: - decoder = BertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) - elif 'openai-gpt' in decoder_pretrained_model_name_or_path: - decoder = OpenAIGPTModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) - elif 'gpt2' in decoder_pretrained_model_name_or_path: - decoder = GPT2Model.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) - elif 'transfo-xl' in decoder_pretrained_model_name_or_path: - decoder = TransfoXLModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) - elif 'xlnet' in decoder_pretrained_model_name_or_path: - decoder = XLNetModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) - elif 'xlm' in decoder_pretrained_model_name_or_path: - decoder = XLMModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) else: raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "