take path to pretrained for encoder and decoder for init
This commit is contained in:
@@ -21,21 +21,20 @@ import logging
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
|
||||||
|
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||||
|
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PreTrainedSeq2seq(nn.Module):
|
class PreTrainedSeq2seq(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
:class:`~transformers.Seq2seq` is a generic model class
|
:class:`~transformers.Seq2seq` is a generic model class that will be
|
||||||
that will be instantiated as a Seq2seq model with one of the base model classes of the library
|
instantiated as a Seq2seq model with one of the base model classes of
|
||||||
as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
the library as encoder and (optionally) as decoder when created with
|
||||||
class method.
|
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
|
||||||
|
method.
|
||||||
"""
|
"""
|
||||||
def __init__(self, encoder, decoder):
|
def __init__(self, encoder, decoder):
|
||||||
super(PreTrainedSeq2seq, self).__init__()
|
super(PreTrainedSeq2seq, self).__init__()
|
||||||
@@ -43,7 +42,7 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
r""" Instantiates one of the base model classes of the library
|
r""" Instantiates one of the base model classes of the library
|
||||||
from a pre-trained model configuration.
|
from a pre-trained model configuration.
|
||||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
@@ -100,40 +99,34 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
|
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
|
||||||
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Extract encoder and decoder model if provided
|
|
||||||
encoder_model = kwargs.pop('encoder_model', None)
|
|
||||||
decoder_model = kwargs.pop('decoder_model', None)
|
|
||||||
|
|
||||||
# Extract decoder kwargs so we only have encoder kwargs for now
|
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||||
if decoder_model is None:
|
# decoder-specific it the key starts with `decoder_`
|
||||||
decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path)
|
kwargs_decoder = {}
|
||||||
decoder_kwargs = {}
|
kwargs_encoder = kwargs
|
||||||
for key in kwargs.keys():
|
for key in kwargs_encoder.keys():
|
||||||
if key.startswith('decoder_'):
|
if key.startswith('decoder_'):
|
||||||
decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
|
kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
|
||||||
|
|
||||||
# Load and initialize the decoder
|
# Load and initialize the encoder and decoder
|
||||||
if encoder_model:
|
# The distinction between encoder and decoder at the model level is made
|
||||||
encoder = encoder_model
|
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||||
else:
|
encoder = kwargs.pop('encoder_model', None)
|
||||||
# Load and initialize the encoder
|
if encoder is None:
|
||||||
kwargs['is_decoder'] = False # Make sure the encoder will be an encoder
|
kwargs_encoder['is_decoder'] = False
|
||||||
encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
||||||
|
|
||||||
# Load and initialize the decoder
|
decoder = kwargs.pop('decoder_model', None)
|
||||||
if decoder_model:
|
if decoder is None:
|
||||||
decoder = decoder_model
|
kwargs_decoder['is_decoder'] = True
|
||||||
else:
|
decoder_model = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
|
|
||||||
kwargs['is_decoder'] = True # Make sure the decoder will be a decoder
|
|
||||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
|
||||||
|
|
||||||
model = cls(encoder, decoder)
|
model = cls(encoder, decoder)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def forward(self, *inputs, *kwargs):
|
def forward(self, *inputs, **kwargs):
|
||||||
# Extract decoder inputs
|
# Extract decoder inputs
|
||||||
decoder_kwargs = {}
|
decoder_kwargs = {}
|
||||||
for key in kwargs.keys():
|
for key in kwargs.keys():
|
||||||
|
|||||||
Reference in New Issue
Block a user