maxi simplification
This commit is contained in:
@@ -21,14 +21,7 @@ import logging
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering
|
||||
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel
|
||||
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel
|
||||
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
|
||||
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering
|
||||
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
|
||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
|
||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||
|
||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||
|
||||
@@ -43,22 +36,6 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
that will be instantiated as a Seq2seq model with one of the base model classes of the library
|
||||
as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The base model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||
- contains `bert`: BertModel (Bert model)
|
||||
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||
- contains `xlnet`: XLNetModel (XLNet model)
|
||||
- contains `xlm`: XLMModel (XLM model)
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
def __init__(self, encoder, decoder):
|
||||
super(PreTrainedSeq2seq, self).__init__()
|
||||
@@ -69,18 +46,6 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r""" Instantiates one of the base model classes of the library
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||
- contains `bert`: BertModel (Bert model)
|
||||
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||
- contains `xlnet`: XLNetModel (XLNet model)
|
||||
- contains `xlm`: XLMModel (XLM model)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with `model.train()`
|
||||
|
||||
@@ -155,26 +120,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
else:
|
||||
# Load and initialize the encoder
|
||||
kwargs['is_decoder'] = False # Make sure the encoder will be an encoder
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
encoder = DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
encoder = RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
encoder = BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'openai-gpt' in pretrained_model_name_or_path:
|
||||
encoder = OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'gpt2' in pretrained_model_name_or_path:
|
||||
encoder = GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'transfo-xl' in pretrained_model_name_or_path:
|
||||
encoder = TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlnet' in pretrained_model_name_or_path:
|
||||
encoder = XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
encoder = XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
|
||||
encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
# Load and initialize the decoder
|
||||
if decoder_model:
|
||||
@@ -182,22 +128,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
else:
|
||||
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
|
||||
kwargs['is_decoder'] = True # Make sure the decoder will be a decoder
|
||||
if 'distilbert' in decoder_pretrained_model_name_or_path:
|
||||
decoder = DistilBertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'roberta' in decoder_pretrained_model_name_or_path:
|
||||
decoder = RobertaModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'bert' in decoder_pretrained_model_name_or_path:
|
||||
decoder = BertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'openai-gpt' in decoder_pretrained_model_name_or_path:
|
||||
decoder = OpenAIGPTModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'gpt2' in decoder_pretrained_model_name_or_path:
|
||||
decoder = GPT2Model.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'transfo-xl' in decoder_pretrained_model_name_or_path:
|
||||
decoder = TransfoXLModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'xlnet' in decoder_pretrained_model_name_or_path:
|
||||
decoder = XLNetModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'xlm' in decoder_pretrained_model_name_or_path:
|
||||
decoder = XLMModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
|
||||
Reference in New Issue
Block a user