take path to pretrained for encoder and decoder for init
This commit is contained in:
@@ -21,21 +21,20 @@ import logging
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||
|
||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreTrainedSeq2seq(nn.Module):
|
||||
r"""
|
||||
:class:`~transformers.Seq2seq` is a generic model class
|
||||
that will be instantiated as a Seq2seq model with one of the base model classes of the library
|
||||
as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
||||
class method.
|
||||
:class:`~transformers.Seq2seq` is a generic model class that will be
|
||||
instantiated as a Seq2seq model with one of the base model classes of
|
||||
the library as encoder and (optionally) as decoder when created with
|
||||
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
|
||||
method.
|
||||
"""
|
||||
def __init__(self, encoder, decoder):
|
||||
super(PreTrainedSeq2seq, self).__init__()
|
||||
@@ -43,7 +42,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
self.decoder = decoder
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r""" Instantiates one of the base model classes of the library
|
||||
from a pre-trained model configuration.
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
@@ -100,40 +99,34 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
|
||||
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
# Extract encoder and decoder model if provided
|
||||
encoder_model = kwargs.pop('encoder_model', None)
|
||||
decoder_model = kwargs.pop('decoder_model', None)
|
||||
|
||||
# Extract decoder kwargs so we only have encoder kwargs for now
|
||||
if decoder_model is None:
|
||||
decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path)
|
||||
decoder_kwargs = {}
|
||||
for key in kwargs.keys():
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific it the key starts with `decoder_`
|
||||
kwargs_decoder = {}
|
||||
kwargs_encoder = kwargs
|
||||
for key in kwargs_encoder.keys():
|
||||
if key.startswith('decoder_'):
|
||||
decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
|
||||
kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
|
||||
|
||||
# Load and initialize the decoder
|
||||
if encoder_model:
|
||||
encoder = encoder_model
|
||||
else:
|
||||
# Load and initialize the encoder
|
||||
kwargs['is_decoder'] = False # Make sure the encoder will be an encoder
|
||||
encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
# Load and initialize the encoder and decoder
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs.pop('encoder_model', None)
|
||||
if encoder is None:
|
||||
kwargs_encoder['is_decoder'] = False
|
||||
encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
||||
|
||||
# Load and initialize the decoder
|
||||
if decoder_model:
|
||||
decoder = decoder_model
|
||||
else:
|
||||
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
|
||||
kwargs['is_decoder'] = True # Make sure the decoder will be a decoder
|
||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
decoder = kwargs.pop('decoder_model', None)
|
||||
if decoder is None:
|
||||
kwargs_decoder['is_decoder'] = True
|
||||
decoder_model = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
model = cls(encoder, decoder)
|
||||
|
||||
return model
|
||||
|
||||
def forward(self, *inputs, *kwargs):
|
||||
def forward(self, *inputs, **kwargs):
|
||||
# Extract decoder inputs
|
||||
decoder_kwargs = {}
|
||||
for key in kwargs.keys():
|
||||
|
||||
Reference in New Issue
Block a user