From 0ef9bc923a3fa3f12d39a516aec2069e9ffc4e6e Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 14 Oct 2019 11:58:13 +0200 Subject: [PATCH] Cleaning up seq2seq [WIP] --- transformers/modeling_bert.py | 284 +++---------------------------- transformers/modeling_seq2seq.py | 249 +++++++++++++++++++++++++++ 2 files changed, 273 insertions(+), 260 deletions(-) create mode 100644 transformers/modeling_seq2seq.py diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 03559ad26c..fbf3c84646 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -199,12 +199,14 @@ class BertSelfAttention(nn.Module): return x.permute(0, 2, 1, 3) def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None): - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - if encoder_hidden_states is not None: # if encoder-decoder attention - mixed_query_layer = self.query(encoder_hidden_states) + mixed_query_layer = self.query(hidden_states) + # if the attention Module is a encoder-decoder self attention module + if encoder_hidden_states is not None: + mixed_key_layer = self.key(encoder_hidden_states) + mixed_value_layer = self.value(encoder_hidden_states) else: - mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) key_layer = self.transpose_for_scores(mixed_key_layer) @@ -322,26 +324,25 @@ class BertLayer(nn.Module): def __init__(self, config): super(BertLayer, self).__init__() self.attention = BertAttention(config) - if getattr(config, "is_decoder", False): + self.is_decoder = config.is_decoder + if self.is_decoder: self.crossattention = BertAttention(config) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None): - attention_outputs = self.attention(hidden_states, attention_mask, head_mask) - attention_output = attention_outputs[0] + self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - if encoder_hidden_state is not None: - try: - attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state) - except AttributeError as ae: - print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae) - raise + if self.is_decoder and encoder_hidden_state is not None: + cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights - attention_output = attention_outputs[0] intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + outputs = (layer_output,) + outputs return outputs @@ -352,14 +353,14 @@ class BertEncoder(nn.Module): self.output_hidden_states = config.output_hidden_states self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) - def forward(self, hidden_states, attention_mask=None, head_mask=None): + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None): all_hidden_states = () all_attentions = () for i, layer_module in enumerate(self.layer): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i]) + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states) hidden_states = layer_outputs[0] if self.output_attentions: @@ -377,42 +378,6 @@ class BertEncoder(nn.Module): return outputs # last-layer hidden state, (all hidden states), (all attentions) -class BertDecoder(nn.Module): - def __init__(self, config): - super(BertDecoder, self).__init__() - config.is_decoder = True - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) - - def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None): - all_hidden_states = () - all_attentions = () - for i, layer_module in enumerate(self.layer): - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module(hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_state=encoder_outputs) - if self.output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - hidden_states = layer_outputs[0] - - # Add last layer - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = (hidden_states,) - if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) - if self.output_attentions: - outputs = outputs + (all_attentions,) - return outputs # last-layer hidden state, (all hidden states), (all attentions) - - class BertPooler(nn.Module): def __init__(self, config): super(BertPooler, self).__init__() @@ -635,7 +600,8 @@ class BertModel(BertPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): + def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, + head_mask=None, encoder_hidden_state=None): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: @@ -673,8 +639,9 @@ class BertModel(BertPreTrainedModel): embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids) encoder_outputs = self.encoder(embedding_output, - extended_attention_mask, - head_mask=head_mask) + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_state=encoder_hidden_state) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) @@ -682,111 +649,6 @@ class BertModel(BertPreTrainedModel): return outputs # sequence_output, pooled_output, (hidden_states), (attentions) -@add_start_docstrings("""A bare Bert decoder Model transformer outputting raw hidden-states without any specific head on top. - The model follows the general transformer decoder architecture.""", - BERT_START_DOCSTRING, - BERT_INPUTS_DOCSTRING) -class BertDecoderModel(BertPreTrainedModel): - r""" - Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: - **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` - Sequence of hidden-states at the output of the last layer of the model. - **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` - Last layer hidden-state of the first token of the sequence (classification token) - further processed by a Linear layer and a Tanh activation function. The Linear - layer weights are trained from the next sentence prediction (classification) - objective during Bert pretraining. This output is usually *not* a good summary - of the semantic content of the input, you're often better with averaging or pooling - the sequence of hidden-states for the whole input sequence. - **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) - list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) - of shape ``(batch_size, sequence_length, hidden_size)``: - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - **attentions**: (`optional`, returned when ``config.output_attentions=True``) - list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - - Examples:: - - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - model = BertDecoderModel.from_pretrained('bert-base-uncased') - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 - outputs = model(input_ids) - last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple - - """ - def __init__(self, config): - super(BertDecoderModel, self).__init__(config) - - self.embeddings = BertEmbeddings(config) - self.decoder = BertDecoder(config) - self.pooler = BertPooler(config) - - self.init_weights() - - def _resize_token_embeddings(self, new_num_tokens): - old_embeddings = self.embeddings.word_embeddings - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) - self.embeddings.word_embeddings = new_embeddings - return self.embeddings.word_embeddings - - def _prune_heads(self, heads_to_prune): - """ Prunes heads of the model. - heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - See base class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.decoder.layer[layer].attention.prune_heads(heads) - self.decoder.layer[layer].crossattention.prune_heads(heads) - - def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer - head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_hidden_layers - - embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids) - decoder_outputs = self.decoder(embedding_output, - encoder_outputs, - extended_attention_mask, - head_mask=head_mask) - sequence_output = decoder_outputs[0] - pooled_output = self.pooler(sequence_output) - - outputs = (sequence_output, pooled_output,) + decoder_outputs[1:] # add hidden_states and attentions if they are here - return outputs # sequence_output, pooled_output, (hidden_states), (attentions) - - @add_start_docstrings("""Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and a `next sentence prediction (classification)` head. """, BERT_START_DOCSTRING, @@ -1309,101 +1171,3 @@ class BertForQuestionAnswering(BertPreTrainedModel): outputs = (total_loss,) + outputs return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) - - -@add_start_docstrings("Bert encoder-decoder model for sequence generation.", - BERT_START_DOCSTRING, - BERT_INPUTS_DOCSTRING) -class Bert2Rnd(BertPreTrainedModel): - r""" - - Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: - **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) - list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) - of shape ``(batch_size, sequence_length, hidden_size)``: - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - **attentions**: (`optional`, returned when ``config.output_attentions=True``) - list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - - Examples:: - - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - model = Bert2Rnd.from_pretrained('bert-base-uncased') - # fine-tuning magic happens here - input = tokenizer.encode("Hello, how are you?") - outputs = model(input) - output_text = tokenize.decode(outputs[0]) - print(output_text) - - References:: - - [1] "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", S.Rothe, S.Narayan & A.Severyn (2019) ArXiV:1907.12461v1 - [2] Tensor2Tensor library https://github.com/tensorflow/tensor2tensor - - """ - - def __init__(self, config): - super(Bert2Rnd, self).__init__(config) - self.encoder = BertModel(config) - self.decoder = BertDecoderModel(config) - - @classmethod - def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs): - """ Load the pretrained weights in the encoder. - - The encoder of `Bert2Rand` is initialized with pretrained weights; the - weights of the decoder are initialized at random except the embeddings - which are initialized with the pretrained embeddings. We thus need to override - the base class' `from_pretrained` method. - """ - - # Load the configuration - config = model_kwargs.pop('config', None) - if config is None: - cache_dir = model_kwargs.pop('cache_dir', None) - force_download = model_kwargs.pop('force_download', False) - config, _ = cls.config_class.from_pretrained( - pretrained_model_or_path, - *model_args, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - **model_kwargs - ) - model = cls(config) - - # We load the encoder with pretrained weights - pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs) - model.encoder = pretrained_encoder - - # We load the decoder with pretrained weights and then randomize all weights but embeddings-related one. - def randomize_decoder_weights(module): - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=config.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs) - pretrained_decoder.apply(randomize_decoder_weights) - model.decoder = pretrained_decoder - - return model - - def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): - encoder_outputs = self.encoder(input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask) - decoder_outputs = self.decoder(input_ids, - encoder_outputs[0], - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask) - return decoder_outputs diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py new file mode 100644 index 0000000000..50891ddded --- /dev/null +++ b/transformers/modeling_seq2seq.py @@ -0,0 +1,249 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Auto Model class. """ + +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging + +import torch +from torch import nn + +from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering +from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel +from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel +from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel +from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering +from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering +from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification +from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification + +from .modeling_utils import PreTrainedModel, SequenceSummary + +from .file_utils import add_start_docstrings + +logger = logging.getLogger(__name__) + + +class PreTrainedSeq2seq(nn.Module): + r""" + :class:`~transformers.Seq2seq` is a generic model class + that will be instantiated as a Seq2seq model with one of the base model classes of the library + as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` + class method. + + The `from_pretrained()` method takes care of returning the correct model class instance + using pattern matching on the `pretrained_model_name_or_path` string. + + The base model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertModel (DistilBERT model) + - contains `roberta`: RobertaModel (RoBERTa model) + - contains `bert`: BertModel (Bert model) + - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) + - contains `gpt2`: GPT2Model (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLModel (Transformer-XL model) + - contains `xlnet`: XLNetModel (XLNet model) + - contains `xlm`: XLMModel (XLM model) + + This class cannot be instantiated using `__init__()` (throws an error). + """ + def __init__(self, encoder, decoder): + super(PreTrainedSeq2seq, self).__init__() + self.encoder = encoder + self.decoder = decoder + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" Instantiates one of the base model classes of the library + from a pre-trained model configuration. + + The model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertModel (DistilBERT model) + - contains `roberta`: RobertaModel (RoBERTa model) + - contains `bert`: BertModel (Bert model) + - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) + - contains `gpt2`: GPT2Model (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLModel (Transformer-XL model) + - contains `xlnet`: XLNetModel (XLNet model) + - contains `xlm`: XLMModel (XLM model) + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Params: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + + config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + + kwargs: (`optional`) Remaining dictionary of keyword arguments: + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + + Examples:: + + model = AutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. + model = AutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + model = AutoModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json') + model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + # Extract encoder and decoder model if provided + encoder_model = kwargs.pop('encoder_model', None) + decoder_model = kwargs.pop('decoder_model', None) + + # Extract decoder kwargs so we only have encoder kwargs for now + if decoder_model is None: + decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path) + decoder_kwargs = {} + for key in kwargs.keys(): + if key.startswith('decoder_'): + decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key) + + # Load and initialize the decoder + if encoder_model: + encoder = encoder_model + else: + # Load and initialize the encoder + kwargs['is_decoder'] = False # Make sure the encoder will be an encoder + if 'distilbert' in pretrained_model_name_or_path: + encoder = DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'roberta' in pretrained_model_name_or_path: + encoder = RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'bert' in pretrained_model_name_or_path: + encoder = BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'openai-gpt' in pretrained_model_name_or_path: + encoder = OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'gpt2' in pretrained_model_name_or_path: + encoder = GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'transfo-xl' in pretrained_model_name_or_path: + encoder = TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'xlnet' in pretrained_model_name_or_path: + encoder = XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'xlm' in pretrained_model_name_or_path: + encoder = XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + else: + raise ValueError("Unrecognized model identifier in {}. Should contains one of " + "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " + "'xlm', 'roberta'".format(pretrained_model_name_or_path)) + + # Load and initialize the decoder + if decoder_model: + decoder = decoder_model + else: + kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc... + kwargs['is_decoder'] = True # Make sure the decoder will be an decoder + if 'distilbert' in decoder_pretrained_model_name_or_path: + decoder = DistilBertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + elif 'roberta' in decoder_pretrained_model_name_or_path: + decoder = RobertaModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + elif 'bert' in decoder_pretrained_model_name_or_path: + decoder = BertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + elif 'openai-gpt' in decoder_pretrained_model_name_or_path: + decoder = OpenAIGPTModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + elif 'gpt2' in decoder_pretrained_model_name_or_path: + decoder = GPT2Model.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + elif 'transfo-xl' in decoder_pretrained_model_name_or_path: + decoder = TransfoXLModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + elif 'xlnet' in decoder_pretrained_model_name_or_path: + decoder = XLNetModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + elif 'xlm' in decoder_pretrained_model_name_or_path: + decoder = XLMModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + else: + raise ValueError("Unrecognized model identifier in {}. Should contains one of " + "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " + "'xlm', 'roberta'".format(decoder_pretrained_model_name_or_path)) + + model = cls(encoder, decoder) + return model + + def forward(self, *inputs, *kwargs): + # Extract decoder inputs + decoder_kwargs = {} + for key in kwargs.keys(): + if key.startswith('decoder_'): + decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key) + + # Compute encoder hidden states if needed + encoder_hidden_states = kwargs.pop('encoder_hidden_states', None) + if encoder_hidden_states is None: + encoder_outputs = self.encoder(*inputs, *kwargs) + encoder_hidden_states = encoder_outputs[0] + + # Decode + decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states + decoder_outputs = self.decoder(**decoder_kwargs) + + return decoder_outputs + + +class Model2Model(PreTrainedSeq2seq): + def tie_weights(): + # We should tie encoder and decoder embeddings if possible here + pass + + +class Model2LSTM(PreTrainedSeq2seq): + @classmethod + def from_pretrained(cls, *args, **kwargs): + if kwargs.get('decoder_model', None) is None: + # We will create a randomly initilized LSTM model as decoder + if 'decoder_config' not in kwargs: + raise ValueError("To load an LSTM in Seq2seq model, please supply either: " + " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or " + " - a dictionary of configuration parameters that will be used to initialize a + " torch.nn.LSTM model as `decoder_config` keyword argument. " + " E.g. `decoder_config=\{'input_size': 768, 'hidden_size': 768, 'num_layers': 2\}`") + kwargs['decoder_model'] = torch.nn.LSTM(kwarg.pop('decoder_config')) + model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs) + return model +