From 9b71fc9a18bbd49a699a338abe1891320c818108 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 16 Oct 2019 21:31:38 +0200 Subject: [PATCH] tying weights is going to be a clusterfuck --- transformers/modeling_seq2seq.py | 81 ++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 25 deletions(-) diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index 8f27224a56..4e76a1b8e7 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -28,7 +28,7 @@ from .modeling_utils import PreTrainedModel, SequenceSummary logger = logging.getLogger(__name__) -class PreTrainedSeq2seq(nn.Module): +class PreTrainedSeq2seq(PreTrainedModel): r""" :class:`~transformers.Seq2seq` is a generic model class that will be instantiated as a Seq2seq model with one of the base model classes of @@ -36,13 +36,20 @@ class PreTrainedSeq2seq(nn.Module): the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class method. """ + def __init__(self, encoder, decoder): super(PreTrainedSeq2seq, self).__init__() self.encoder = encoder self.decoder = decoder @classmethod - def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs): + def from_pretrained( + cls, + encoder_pretrained_model_name_or_path, + decoder_pretrained_model_name_or_path, + *model_args, + **kwargs + ): r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints. @@ -110,21 +117,25 @@ class PreTrainedSeq2seq(nn.Module): kwargs_decoder = {} kwargs_encoder = kwargs for key in kwargs_encoder.keys(): - if key.startswith('decoder_'): - kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key) + if key.startswith("decoder_"): + kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key) # Load and initialize the encoder and decoder # The distinction between encoder and decoder at the model level is made # by the value of the flag `is_decoder` that we need to set correctly. - encoder = kwargs.pop('encoder_model', None) + encoder = kwargs.pop("encoder_model", None) if encoder is None: - kwargs_encoder['is_decoder'] = False - encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) + kwargs_encoder["is_decoder"] = False + encoder = AutoModel.from_pretrained( + encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder + ) - decoder = kwargs.pop('decoder_model', None) + decoder = kwargs.pop("decoder_model", None) if decoder is None: - kwargs_decoder['is_decoder'] = True - decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + kwargs_decoder["is_decoder"] = True + decoder = AutoModelWithLMHead.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder + ) model = cls(encoder, decoder) @@ -153,11 +164,11 @@ class PreTrainedSeq2seq(nn.Module): kwargs_decoder = {} kwargs_encoder = kwargs for key in kwargs_encoder.keys(): - if key.startswith('decoder_'): - kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key) + if key.startswith("decoder_"): + kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key) # Encode if needed (training, first prediction pass) - encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None) + encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None) if encoder_hidden_states is None: encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) encoder_hidden_states = encoder_outputs[0] @@ -165,29 +176,49 @@ class PreTrainedSeq2seq(nn.Module): encoder_outputs = () # Decode - kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states + kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) return decoder_outputs + encoder_outputs class Model2Model(PreTrainedSeq2seq): - def tie_weights(): - # We should tie encoder and decoder embeddings if possible here - pass + def __init__(self): + super(Model2Model, self).__init__() + self.tie_weights() + + def tie_weights(self): + """ Tying the encoder and decoders' embeddings together. + + We need for each to get down to the embedding weights. However the + different model classes are inconsistent to that respect: + - BertModel: embeddings.word_embeddings + - RoBERTa: embeddings.word_embeddings + - XLMModel: embeddings + - GPT2: wte + - BertForMaskedLM: bert.embeddings.word_embeddings + - RobertaForMaskedLM: roberta.embeddings.word_embeddings + + argument of the XEmbedding layer for each model, but it is "blocked" + by a model-specific keyword (bert, )... + """ + # self._tie_or_clone_weights(self.encoder, self.decoder) + raise NotImplementedError class Model2LSTM(PreTrainedSeq2seq): @classmethod def from_pretrained(cls, *args, **kwargs): - if kwargs.get('decoder_model', None) is None: + if kwargs.get("decoder_model", None) is None: # We will create a randomly initilized LSTM model as decoder - if 'decoder_config' not in kwargs: - raise ValueError("To load an LSTM in Seq2seq model, please supply either: " - " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or" - " - a dictionary of configuration parameters that will be used to initialize a" - " torch.nn.LSTM model as `decoder_config` keyword argument. " - " E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`") - kwargs['decoder_model'] = torch.nn.LSTM(kwargs.pop('decoder_config')) + if "decoder_config" not in kwargs: + raise ValueError( + "To load an LSTM in Seq2seq model, please supply either: " + " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or" + " - a dictionary of configuration parameters that will be used to initialize a" + " torch.nn.LSTM model as `decoder_config` keyword argument. " + " E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`" + ) + kwargs["decoder_model"] = torch.nn.LSTM(kwargs.pop("decoder_config")) model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs) return model