Cleaning up seq2seq [WIP]
This commit is contained in:
@@ -199,12 +199,14 @@ class BertSelfAttention(nn.Module):
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
# if the attention Module is a encoder-decoder self attention module
|
||||
if encoder_hidden_states is not None:
|
||||
mixed_key_layer = self.key(encoder_hidden_states)
|
||||
mixed_value_layer = self.value(encoder_hidden_states)
|
||||
else:
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
if encoder_hidden_states is not None: # if encoder-decoder attention
|
||||
mixed_query_layer = self.query(encoder_hidden_states)
|
||||
else:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
@@ -322,26 +324,25 @@ class BertLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertLayer, self).__init__()
|
||||
self.attention = BertAttention(config)
|
||||
if getattr(config, "is_decoder", False):
|
||||
self.is_decoder = config.is_decoder
|
||||
if self.is_decoder:
|
||||
self.crossattention = BertAttention(config)
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
|
||||
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||
attention_output = attention_outputs[0]
|
||||
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
if encoder_hidden_state is not None:
|
||||
try:
|
||||
attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
|
||||
except AttributeError as ae:
|
||||
print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
|
||||
raise
|
||||
if self.is_decoder and encoder_hidden_state is not None:
|
||||
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
outputs = (layer_output,) + outputs
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -352,14 +353,14 @@ class BertEncoder(nn.Module):
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
|
||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
@@ -377,42 +378,6 @@ class BertEncoder(nn.Module):
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
class BertDecoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertDecoder, self).__init__()
|
||||
config.is_decoder = True
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_state=encoder_outputs)
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertPooler, self).__init__()
|
||||
@@ -635,7 +600,8 @@ class BertModel(BertPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
|
||||
head_mask=None, encoder_hidden_state=None):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
@@ -673,8 +639,9 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
encoder_outputs = self.encoder(embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask=head_mask)
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_state=encoder_hidden_state)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
@@ -682,111 +649,6 @@ class BertModel(BertPreTrainedModel):
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""A bare Bert decoder Model transformer outputting raw hidden-states without any specific head on top.
|
||||
The model follows the general transformer decoder architecture.""",
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertDecoderModel(BertPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during Bert pretraining. This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = BertDecoderModel.from_pretrained('bert-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(BertDecoderModel, self).__init__(config)
|
||||
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.decoder = BertDecoder(config)
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.decoder.layer[layer].attention.prune_heads(heads)
|
||||
self.decoder.layer[layer].crossattention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
decoder_outputs = self.decoder(embedding_output,
|
||||
encoder_outputs,
|
||||
extended_attention_mask,
|
||||
head_mask=head_mask)
|
||||
sequence_output = decoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + decoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||
BERT_START_DOCSTRING,
|
||||
@@ -1309,101 +1171,3 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class Bert2Rnd(BertPreTrainedModel):
|
||||
r"""
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = Bert2Rnd.from_pretrained('bert-base-uncased')
|
||||
# fine-tuning magic happens here
|
||||
input = tokenizer.encode("Hello, how are you?")
|
||||
outputs = model(input)
|
||||
output_text = tokenize.decode(outputs[0])
|
||||
print(output_text)
|
||||
|
||||
References::
|
||||
|
||||
[1] "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", S.Rothe, S.Narayan & A.Severyn (2019) ArXiV:1907.12461v1
|
||||
[2] Tensor2Tensor library https://github.com/tensorflow/tensor2tensor
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(Bert2Rnd, self).__init__(config)
|
||||
self.encoder = BertModel(config)
|
||||
self.decoder = BertDecoderModel(config)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
|
||||
""" Load the pretrained weights in the encoder.
|
||||
|
||||
The encoder of `Bert2Rand` is initialized with pretrained weights; the
|
||||
weights of the decoder are initialized at random except the embeddings
|
||||
which are initialized with the pretrained embeddings. We thus need to override
|
||||
the base class' `from_pretrained` method.
|
||||
"""
|
||||
|
||||
# Load the configuration
|
||||
config = model_kwargs.pop('config', None)
|
||||
if config is None:
|
||||
cache_dir = model_kwargs.pop('cache_dir', None)
|
||||
force_download = model_kwargs.pop('force_download', False)
|
||||
config, _ = cls.config_class.from_pretrained(
|
||||
pretrained_model_or_path,
|
||||
*model_args,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
**model_kwargs
|
||||
)
|
||||
model = cls(config)
|
||||
|
||||
# We load the encoder with pretrained weights
|
||||
pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
|
||||
model.encoder = pretrained_encoder
|
||||
|
||||
# We load the decoder with pretrained weights and then randomize all weights but embeddings-related one.
|
||||
def randomize_decoder_weights(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
elif isinstance(module, BertLayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
|
||||
pretrained_decoder.apply(randomize_decoder_weights)
|
||||
model.decoder = pretrained_decoder
|
||||
|
||||
return model
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
encoder_outputs = self.encoder(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
decoder_outputs = self.decoder(input_ids,
|
||||
encoder_outputs[0],
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
return decoder_outputs
|
||||
|
||||
249
transformers/modeling_seq2seq.py
Normal file
249
transformers/modeling_seq2seq.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Auto Model class. """
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering
|
||||
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel
|
||||
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel
|
||||
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
|
||||
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering
|
||||
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
|
||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
|
||||
|
||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreTrainedSeq2seq(nn.Module):
|
||||
r"""
|
||||
:class:`~transformers.Seq2seq` is a generic model class
|
||||
that will be instantiated as a Seq2seq model with one of the base model classes of the library
|
||||
as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The base model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||
- contains `bert`: BertModel (Bert model)
|
||||
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||
- contains `xlnet`: XLNetModel (XLNet model)
|
||||
- contains `xlm`: XLMModel (XLM model)
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
def __init__(self, encoder, decoder):
|
||||
super(PreTrainedSeq2seq, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r""" Instantiates one of the base model classes of the library
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||
- contains `bert`: BertModel (Bert model)
|
||||
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||
- contains `xlnet`: XLNetModel (XLNet model)
|
||||
- contains `xlm`: XLMModel (XLM model)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with `model.train()`
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path: either:
|
||||
|
||||
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
|
||||
model_args: (`optional`) Sequence of positional arguments:
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||
|
||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||
|
||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||
|
||||
state_dict: (`optional`) dict:
|
||||
an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
|
||||
cache_dir: (`optional`) string:
|
||||
Path to a directory in which a downloaded pre-trained model
|
||||
configuration should be cached if the standard cache should not be used.
|
||||
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
|
||||
|
||||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
|
||||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
|
||||
|
||||
Examples::
|
||||
|
||||
model = AutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
|
||||
model = AutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
model = AutoModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
|
||||
assert model.config.output_attention == True
|
||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
|
||||
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
# Extract encoder and decoder model if provided
|
||||
encoder_model = kwargs.pop('encoder_model', None)
|
||||
decoder_model = kwargs.pop('decoder_model', None)
|
||||
|
||||
# Extract decoder kwargs so we only have encoder kwargs for now
|
||||
if decoder_model is None:
|
||||
decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path)
|
||||
decoder_kwargs = {}
|
||||
for key in kwargs.keys():
|
||||
if key.startswith('decoder_'):
|
||||
decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
|
||||
|
||||
# Load and initialize the decoder
|
||||
if encoder_model:
|
||||
encoder = encoder_model
|
||||
else:
|
||||
# Load and initialize the encoder
|
||||
kwargs['is_decoder'] = False # Make sure the encoder will be an encoder
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
encoder = DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
encoder = RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
encoder = BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'openai-gpt' in pretrained_model_name_or_path:
|
||||
encoder = OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'gpt2' in pretrained_model_name_or_path:
|
||||
encoder = GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'transfo-xl' in pretrained_model_name_or_path:
|
||||
encoder = TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlnet' in pretrained_model_name_or_path:
|
||||
encoder = XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
encoder = XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
|
||||
|
||||
# Load and initialize the decoder
|
||||
if decoder_model:
|
||||
decoder = decoder_model
|
||||
else:
|
||||
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
|
||||
kwargs['is_decoder'] = True # Make sure the decoder will be an decoder
|
||||
if 'distilbert' in decoder_pretrained_model_name_or_path:
|
||||
decoder = DistilBertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'roberta' in decoder_pretrained_model_name_or_path:
|
||||
decoder = RobertaModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'bert' in decoder_pretrained_model_name_or_path:
|
||||
decoder = BertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'openai-gpt' in decoder_pretrained_model_name_or_path:
|
||||
decoder = OpenAIGPTModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'gpt2' in decoder_pretrained_model_name_or_path:
|
||||
decoder = GPT2Model.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'transfo-xl' in decoder_pretrained_model_name_or_path:
|
||||
decoder = TransfoXLModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'xlnet' in decoder_pretrained_model_name_or_path:
|
||||
decoder = XLNetModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
elif 'xlm' in decoder_pretrained_model_name_or_path:
|
||||
decoder = XLMModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm', 'roberta'".format(decoder_pretrained_model_name_or_path))
|
||||
|
||||
model = cls(encoder, decoder)
|
||||
return model
|
||||
|
||||
def forward(self, *inputs, *kwargs):
|
||||
# Extract decoder inputs
|
||||
decoder_kwargs = {}
|
||||
for key in kwargs.keys():
|
||||
if key.startswith('decoder_'):
|
||||
decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
|
||||
|
||||
# Compute encoder hidden states if needed
|
||||
encoder_hidden_states = kwargs.pop('encoder_hidden_states', None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(*inputs, *kwargs)
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
|
||||
# Decode
|
||||
decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states
|
||||
decoder_outputs = self.decoder(**decoder_kwargs)
|
||||
|
||||
return decoder_outputs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedSeq2seq):
|
||||
def tie_weights():
|
||||
# We should tie encoder and decoder embeddings if possible here
|
||||
pass
|
||||
|
||||
|
||||
class Model2LSTM(PreTrainedSeq2seq):
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
if kwargs.get('decoder_model', None) is None:
|
||||
# We will create a randomly initilized LSTM model as decoder
|
||||
if 'decoder_config' not in kwargs:
|
||||
raise ValueError("To load an LSTM in Seq2seq model, please supply either: "
|
||||
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or "
|
||||
" - a dictionary of configuration parameters that will be used to initialize a
|
||||
" torch.nn.LSTM model as `decoder_config` keyword argument. "
|
||||
" E.g. `decoder_config=\{'input_size': 768, 'hidden_size': 768, 'num_layers': 2\}`")
|
||||
kwargs['decoder_model'] = torch.nn.LSTM(kwarg.pop('decoder_config'))
|
||||
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user