tying weights is going to be a clusterfuck
This commit is contained in:
@@ -28,7 +28,7 @@ from .modeling_utils import PreTrainedModel, SequenceSummary
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PreTrainedSeq2seq(nn.Module):
|
class PreTrainedSeq2seq(PreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
:class:`~transformers.Seq2seq` is a generic model class that will be
|
:class:`~transformers.Seq2seq` is a generic model class that will be
|
||||||
instantiated as a Seq2seq model with one of the base model classes of
|
instantiated as a Seq2seq model with one of the base model classes of
|
||||||
@@ -36,13 +36,20 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
|
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
|
||||||
method.
|
method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, encoder, decoder):
|
def __init__(self, encoder, decoder):
|
||||||
super(PreTrainedSeq2seq, self).__init__()
|
super(PreTrainedSeq2seq, self).__init__()
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
encoder_pretrained_model_name_or_path,
|
||||||
|
decoder_pretrained_model_name_or_path,
|
||||||
|
*model_args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
r""" Instantiates an encoder and a decoder from one or two base classes
|
r""" Instantiates an encoder and a decoder from one or two base classes
|
||||||
of the library from pre-trained model checkpoints.
|
of the library from pre-trained model checkpoints.
|
||||||
|
|
||||||
@@ -110,21 +117,25 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
kwargs_decoder = {}
|
kwargs_decoder = {}
|
||||||
kwargs_encoder = kwargs
|
kwargs_encoder = kwargs
|
||||||
for key in kwargs_encoder.keys():
|
for key in kwargs_encoder.keys():
|
||||||
if key.startswith('decoder_'):
|
if key.startswith("decoder_"):
|
||||||
kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
|
kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
|
||||||
|
|
||||||
# Load and initialize the encoder and decoder
|
# Load and initialize the encoder and decoder
|
||||||
# The distinction between encoder and decoder at the model level is made
|
# The distinction between encoder and decoder at the model level is made
|
||||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||||
encoder = kwargs.pop('encoder_model', None)
|
encoder = kwargs.pop("encoder_model", None)
|
||||||
if encoder is None:
|
if encoder is None:
|
||||||
kwargs_encoder['is_decoder'] = False
|
kwargs_encoder["is_decoder"] = False
|
||||||
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
encoder = AutoModel.from_pretrained(
|
||||||
|
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
|
||||||
|
)
|
||||||
|
|
||||||
decoder = kwargs.pop('decoder_model', None)
|
decoder = kwargs.pop("decoder_model", None)
|
||||||
if decoder is None:
|
if decoder is None:
|
||||||
kwargs_decoder['is_decoder'] = True
|
kwargs_decoder["is_decoder"] = True
|
||||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
decoder = AutoModelWithLMHead.from_pretrained(
|
||||||
|
decoder_pretrained_model_name_or_path, **kwargs_decoder
|
||||||
|
)
|
||||||
|
|
||||||
model = cls(encoder, decoder)
|
model = cls(encoder, decoder)
|
||||||
|
|
||||||
@@ -153,11 +164,11 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
kwargs_decoder = {}
|
kwargs_decoder = {}
|
||||||
kwargs_encoder = kwargs
|
kwargs_encoder = kwargs
|
||||||
for key in kwargs_encoder.keys():
|
for key in kwargs_encoder.keys():
|
||||||
if key.startswith('decoder_'):
|
if key.startswith("decoder_"):
|
||||||
kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
|
kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
|
||||||
|
|
||||||
# Encode if needed (training, first prediction pass)
|
# Encode if needed (training, first prediction pass)
|
||||||
encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
|
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
|
||||||
if encoder_hidden_states is None:
|
if encoder_hidden_states is None:
|
||||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
@@ -165,29 +176,49 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
encoder_outputs = ()
|
encoder_outputs = ()
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
|
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||||
|
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
|
|
||||||
class Model2Model(PreTrainedSeq2seq):
|
class Model2Model(PreTrainedSeq2seq):
|
||||||
def tie_weights():
|
def __init__(self):
|
||||||
# We should tie encoder and decoder embeddings if possible here
|
super(Model2Model, self).__init__()
|
||||||
pass
|
self.tie_weights()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
""" Tying the encoder and decoders' embeddings together.
|
||||||
|
|
||||||
|
We need for each to get down to the embedding weights. However the
|
||||||
|
different model classes are inconsistent to that respect:
|
||||||
|
- BertModel: embeddings.word_embeddings
|
||||||
|
- RoBERTa: embeddings.word_embeddings
|
||||||
|
- XLMModel: embeddings
|
||||||
|
- GPT2: wte
|
||||||
|
- BertForMaskedLM: bert.embeddings.word_embeddings
|
||||||
|
- RobertaForMaskedLM: roberta.embeddings.word_embeddings
|
||||||
|
|
||||||
|
argument of the XEmbedding layer for each model, but it is "blocked"
|
||||||
|
by a model-specific keyword (bert, )...
|
||||||
|
"""
|
||||||
|
# self._tie_or_clone_weights(self.encoder, self.decoder)
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class Model2LSTM(PreTrainedSeq2seq):
|
class Model2LSTM(PreTrainedSeq2seq):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
if kwargs.get('decoder_model', None) is None:
|
if kwargs.get("decoder_model", None) is None:
|
||||||
# We will create a randomly initilized LSTM model as decoder
|
# We will create a randomly initilized LSTM model as decoder
|
||||||
if 'decoder_config' not in kwargs:
|
if "decoder_config" not in kwargs:
|
||||||
raise ValueError("To load an LSTM in Seq2seq model, please supply either: "
|
raise ValueError(
|
||||||
|
"To load an LSTM in Seq2seq model, please supply either: "
|
||||||
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
|
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
|
||||||
" - a dictionary of configuration parameters that will be used to initialize a"
|
" - a dictionary of configuration parameters that will be used to initialize a"
|
||||||
" torch.nn.LSTM model as `decoder_config` keyword argument. "
|
" torch.nn.LSTM model as `decoder_config` keyword argument. "
|
||||||
" E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`")
|
" E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`"
|
||||||
kwargs['decoder_model'] = torch.nn.LSTM(kwargs.pop('decoder_config'))
|
)
|
||||||
|
kwargs["decoder_model"] = torch.nn.LSTM(kwargs.pop("decoder_config"))
|
||||||
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
|
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|||||||
Reference in New Issue
Block a user