Cleaning up seq2seq [WIP]
This commit is contained in:
@@ -199,12 +199,14 @@ class BertSelfAttention(nn.Module):
|
|||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
||||||
mixed_key_layer = self.key(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
mixed_value_layer = self.value(hidden_states)
|
# if the attention Module is a encoder-decoder self attention module
|
||||||
if encoder_hidden_states is not None: # if encoder-decoder attention
|
if encoder_hidden_states is not None:
|
||||||
mixed_query_layer = self.query(encoder_hidden_states)
|
mixed_key_layer = self.key(encoder_hidden_states)
|
||||||
|
mixed_value_layer = self.value(encoder_hidden_states)
|
||||||
else:
|
else:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_key_layer = self.key(hidden_states)
|
||||||
|
mixed_value_layer = self.value(hidden_states)
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||||
@@ -322,26 +324,25 @@ class BertLayer(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertLayer, self).__init__()
|
super(BertLayer, self).__init__()
|
||||||
self.attention = BertAttention(config)
|
self.attention = BertAttention(config)
|
||||||
if getattr(config, "is_decoder", False):
|
self.is_decoder = config.is_decoder
|
||||||
|
if self.is_decoder:
|
||||||
self.crossattention = BertAttention(config)
|
self.crossattention = BertAttention(config)
|
||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
|
||||||
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
if encoder_hidden_state is not None:
|
if self.is_decoder and encoder_hidden_state is not None:
|
||||||
try:
|
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
|
||||||
attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
|
attention_output = cross_attention_outputs[0]
|
||||||
except AttributeError as ae:
|
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||||
print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
|
|
||||||
raise
|
|
||||||
|
|
||||||
attention_output = attention_outputs[0]
|
|
||||||
intermediate_output = self.intermediate(attention_output)
|
intermediate_output = self.intermediate(attention_output)
|
||||||
layer_output = self.output(intermediate_output, attention_output)
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
outputs = (layer_output,) + outputs
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -352,14 +353,14 @@ class BertEncoder(nn.Module):
|
|||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
all_attentions = ()
|
all_attentions = ()
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if self.output_hidden_states:
|
if self.output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
|
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
@@ -377,42 +378,6 @@ class BertEncoder(nn.Module):
|
|||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|
||||||
class BertDecoder(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super(BertDecoder, self).__init__()
|
|
||||||
config.is_decoder = True
|
|
||||||
self.output_attentions = config.output_attentions
|
|
||||||
self.output_hidden_states = config.output_hidden_states
|
|
||||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
|
||||||
|
|
||||||
def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
|
|
||||||
all_hidden_states = ()
|
|
||||||
all_attentions = ()
|
|
||||||
for i, layer_module in enumerate(self.layer):
|
|
||||||
if self.output_hidden_states:
|
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
||||||
|
|
||||||
layer_outputs = layer_module(hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
head_mask=head_mask[i],
|
|
||||||
encoder_hidden_state=encoder_outputs)
|
|
||||||
if self.output_attentions:
|
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
|
||||||
|
|
||||||
# Add last layer
|
|
||||||
if self.output_hidden_states:
|
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
|
||||||
if self.output_hidden_states:
|
|
||||||
outputs = outputs + (all_hidden_states,)
|
|
||||||
if self.output_attentions:
|
|
||||||
outputs = outputs + (all_attentions,)
|
|
||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
|
||||||
|
|
||||||
|
|
||||||
class BertPooler(nn.Module):
|
class BertPooler(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertPooler, self).__init__()
|
super(BertPooler, self).__init__()
|
||||||
@@ -635,7 +600,8 @@ class BertModel(BertPreTrainedModel):
|
|||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
|
||||||
|
head_mask=None, encoder_hidden_state=None):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
@@ -673,8 +639,9 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||||
encoder_outputs = self.encoder(embedding_output,
|
encoder_outputs = self.encoder(embedding_output,
|
||||||
extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=head_mask)
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_state=encoder_hidden_state)
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
pooled_output = self.pooler(sequence_output)
|
pooled_output = self.pooler(sequence_output)
|
||||||
|
|
||||||
@@ -682,111 +649,6 @@ class BertModel(BertPreTrainedModel):
|
|||||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""A bare Bert decoder Model transformer outputting raw hidden-states without any specific head on top.
|
|
||||||
The model follows the general transformer decoder architecture.""",
|
|
||||||
BERT_START_DOCSTRING,
|
|
||||||
BERT_INPUTS_DOCSTRING)
|
|
||||||
class BertDecoderModel(BertPreTrainedModel):
|
|
||||||
r"""
|
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
||||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
|
||||||
Sequence of hidden-states at the output of the last layer of the model.
|
|
||||||
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
|
|
||||||
Last layer hidden-state of the first token of the sequence (classification token)
|
|
||||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
|
||||||
layer weights are trained from the next sentence prediction (classification)
|
|
||||||
objective during Bert pretraining. This output is usually *not* a good summary
|
|
||||||
of the semantic content of the input, you're often better with averaging or pooling
|
|
||||||
the sequence of hidden-states for the whole input sequence.
|
|
||||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
||||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
|
||||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
||||||
model = BertDecoderModel.from_pretrained('bert-base-uncased')
|
|
||||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
|
||||||
outputs = model(input_ids)
|
|
||||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
|
||||||
|
|
||||||
"""
|
|
||||||
def __init__(self, config):
|
|
||||||
super(BertDecoderModel, self).__init__(config)
|
|
||||||
|
|
||||||
self.embeddings = BertEmbeddings(config)
|
|
||||||
self.decoder = BertDecoder(config)
|
|
||||||
self.pooler = BertPooler(config)
|
|
||||||
|
|
||||||
self.init_weights()
|
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
|
||||||
old_embeddings = self.embeddings.word_embeddings
|
|
||||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
|
||||||
self.embeddings.word_embeddings = new_embeddings
|
|
||||||
return self.embeddings.word_embeddings
|
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
|
||||||
""" Prunes heads of the model.
|
|
||||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
|
||||||
See base class PreTrainedModel
|
|
||||||
"""
|
|
||||||
for layer, heads in heads_to_prune.items():
|
|
||||||
self.decoder.layer[layer].attention.prune_heads(heads)
|
|
||||||
self.decoder.layer[layer].crossattention.prune_heads(heads)
|
|
||||||
|
|
||||||
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
if token_type_ids is None:
|
|
||||||
token_type_ids = torch.zeros_like(input_ids)
|
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
|
||||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
|
||||||
# positions we want to attend and -10000.0 for masked positions.
|
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
|
||||||
# effectively the same as removing these entirely.
|
|
||||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
|
||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
|
||||||
# 1.0 in head_mask indicate we keep the head
|
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
|
||||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
|
||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
|
||||||
if head_mask is not None:
|
|
||||||
if head_mask.dim() == 1:
|
|
||||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
|
||||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
|
||||||
elif head_mask.dim() == 2:
|
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
|
||||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
|
||||||
else:
|
|
||||||
head_mask = [None] * self.config.num_hidden_layers
|
|
||||||
|
|
||||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
|
||||||
decoder_outputs = self.decoder(embedding_output,
|
|
||||||
encoder_outputs,
|
|
||||||
extended_attention_mask,
|
|
||||||
head_mask=head_mask)
|
|
||||||
sequence_output = decoder_outputs[0]
|
|
||||||
pooled_output = self.pooler(sequence_output)
|
|
||||||
|
|
||||||
outputs = (sequence_output, pooled_output,) + decoder_outputs[1:] # add hidden_states and attentions if they are here
|
|
||||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
||||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
@@ -1309,101 +1171,3 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
|||||||
outputs = (total_loss,) + outputs
|
outputs = (total_loss,) + outputs
|
||||||
|
|
||||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
|
|
||||||
BERT_START_DOCSTRING,
|
|
||||||
BERT_INPUTS_DOCSTRING)
|
|
||||||
class Bert2Rnd(BertPreTrainedModel):
|
|
||||||
r"""
|
|
||||||
|
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
||||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
||||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
|
||||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
||||||
model = Bert2Rnd.from_pretrained('bert-base-uncased')
|
|
||||||
# fine-tuning magic happens here
|
|
||||||
input = tokenizer.encode("Hello, how are you?")
|
|
||||||
outputs = model(input)
|
|
||||||
output_text = tokenize.decode(outputs[0])
|
|
||||||
print(output_text)
|
|
||||||
|
|
||||||
References::
|
|
||||||
|
|
||||||
[1] "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", S.Rothe, S.Narayan & A.Severyn (2019) ArXiV:1907.12461v1
|
|
||||||
[2] Tensor2Tensor library https://github.com/tensorflow/tensor2tensor
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super(Bert2Rnd, self).__init__(config)
|
|
||||||
self.encoder = BertModel(config)
|
|
||||||
self.decoder = BertDecoderModel(config)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
|
|
||||||
""" Load the pretrained weights in the encoder.
|
|
||||||
|
|
||||||
The encoder of `Bert2Rand` is initialized with pretrained weights; the
|
|
||||||
weights of the decoder are initialized at random except the embeddings
|
|
||||||
which are initialized with the pretrained embeddings. We thus need to override
|
|
||||||
the base class' `from_pretrained` method.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Load the configuration
|
|
||||||
config = model_kwargs.pop('config', None)
|
|
||||||
if config is None:
|
|
||||||
cache_dir = model_kwargs.pop('cache_dir', None)
|
|
||||||
force_download = model_kwargs.pop('force_download', False)
|
|
||||||
config, _ = cls.config_class.from_pretrained(
|
|
||||||
pretrained_model_or_path,
|
|
||||||
*model_args,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
return_unused_kwargs=True,
|
|
||||||
force_download=force_download,
|
|
||||||
**model_kwargs
|
|
||||||
)
|
|
||||||
model = cls(config)
|
|
||||||
|
|
||||||
# We load the encoder with pretrained weights
|
|
||||||
pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
|
|
||||||
model.encoder = pretrained_encoder
|
|
||||||
|
|
||||||
# We load the decoder with pretrained weights and then randomize all weights but embeddings-related one.
|
|
||||||
def randomize_decoder_weights(module):
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
|
||||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
|
||||||
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
|
||||||
elif isinstance(module, BertLayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
|
|
||||||
pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
|
|
||||||
pretrained_decoder.apply(randomize_decoder_weights)
|
|
||||||
model.decoder = pretrained_decoder
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
|
||||||
encoder_outputs = self.encoder(input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
position_ids=position_ids,
|
|
||||||
head_mask=head_mask)
|
|
||||||
decoder_outputs = self.decoder(input_ids,
|
|
||||||
encoder_outputs[0],
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
position_ids=position_ids,
|
|
||||||
head_mask=head_mask)
|
|
||||||
return decoder_outputs
|
|
||||||
|
|||||||
249
transformers/modeling_seq2seq.py
Normal file
249
transformers/modeling_seq2seq.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Auto Model class. """
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering
|
||||||
|
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel
|
||||||
|
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel
|
||||||
|
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
|
||||||
|
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering
|
||||||
|
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
|
||||||
|
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||||
|
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
|
||||||
|
|
||||||
|
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
|
|
||||||
|
from .file_utils import add_start_docstrings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PreTrainedSeq2seq(nn.Module):
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.Seq2seq` is a generic model class
|
||||||
|
that will be instantiated as a Seq2seq model with one of the base model classes of the library
|
||||||
|
as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
||||||
|
class method.
|
||||||
|
|
||||||
|
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||||
|
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||||
|
|
||||||
|
The base model class to instantiate is selected as the first pattern matching
|
||||||
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
|
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||||
|
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||||
|
- contains `bert`: BertModel (Bert model)
|
||||||
|
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||||
|
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||||
|
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||||
|
- contains `xlnet`: XLNetModel (XLNet model)
|
||||||
|
- contains `xlm`: XLMModel (XLM model)
|
||||||
|
|
||||||
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
|
"""
|
||||||
|
def __init__(self, encoder, decoder):
|
||||||
|
super(PreTrainedSeq2seq, self).__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.decoder = decoder
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r""" Instantiates one of the base model classes of the library
|
||||||
|
from a pre-trained model configuration.
|
||||||
|
|
||||||
|
The model class to instantiate is selected as the first pattern matching
|
||||||
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
|
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||||
|
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||||
|
- contains `bert`: BertModel (Bert model)
|
||||||
|
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||||
|
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||||
|
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||||
|
- contains `xlnet`: XLNetModel (XLNet model)
|
||||||
|
- contains `xlm`: XLMModel (XLM model)
|
||||||
|
|
||||||
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
|
To train the model, you should first set it back in training mode with `model.train()`
|
||||||
|
|
||||||
|
Params:
|
||||||
|
pretrained_model_name_or_path: either:
|
||||||
|
|
||||||
|
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||||
|
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||||
|
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||||
|
|
||||||
|
model_args: (`optional`) Sequence of positional arguments:
|
||||||
|
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||||
|
|
||||||
|
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||||
|
|
||||||
|
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||||
|
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||||
|
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||||
|
|
||||||
|
state_dict: (`optional`) dict:
|
||||||
|
an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||||
|
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||||
|
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||||
|
|
||||||
|
cache_dir: (`optional`) string:
|
||||||
|
Path to a directory in which a downloaded pre-trained model
|
||||||
|
configuration should be cached if the standard cache should not be used.
|
||||||
|
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||||
|
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
|
||||||
|
output_loading_info: (`optional`) boolean:
|
||||||
|
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||||
|
|
||||||
|
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||||
|
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
|
||||||
|
|
||||||
|
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
|
||||||
|
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
|
||||||
|
model = AutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
model = AutoModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
|
||||||
|
assert model.config.output_attention == True
|
||||||
|
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
|
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
|
||||||
|
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Extract encoder and decoder model if provided
|
||||||
|
encoder_model = kwargs.pop('encoder_model', None)
|
||||||
|
decoder_model = kwargs.pop('decoder_model', None)
|
||||||
|
|
||||||
|
# Extract decoder kwargs so we only have encoder kwargs for now
|
||||||
|
if decoder_model is None:
|
||||||
|
decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path)
|
||||||
|
decoder_kwargs = {}
|
||||||
|
for key in kwargs.keys():
|
||||||
|
if key.startswith('decoder_'):
|
||||||
|
decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
|
||||||
|
|
||||||
|
# Load and initialize the decoder
|
||||||
|
if encoder_model:
|
||||||
|
encoder = encoder_model
|
||||||
|
else:
|
||||||
|
# Load and initialize the encoder
|
||||||
|
kwargs['is_decoder'] = False # Make sure the encoder will be an encoder
|
||||||
|
if 'distilbert' in pretrained_model_name_or_path:
|
||||||
|
encoder = DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
|
encoder = RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'bert' in pretrained_model_name_or_path:
|
||||||
|
encoder = BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'openai-gpt' in pretrained_model_name_or_path:
|
||||||
|
encoder = OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'gpt2' in pretrained_model_name_or_path:
|
||||||
|
encoder = GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'transfo-xl' in pretrained_model_name_or_path:
|
||||||
|
encoder = TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'xlnet' in pretrained_model_name_or_path:
|
||||||
|
encoder = XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'xlm' in pretrained_model_name_or_path:
|
||||||
|
encoder = XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
|
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
|
||||||
|
|
||||||
|
# Load and initialize the decoder
|
||||||
|
if decoder_model:
|
||||||
|
decoder = decoder_model
|
||||||
|
else:
|
||||||
|
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
|
||||||
|
kwargs['is_decoder'] = True # Make sure the decoder will be an decoder
|
||||||
|
if 'distilbert' in decoder_pretrained_model_name_or_path:
|
||||||
|
decoder = DistilBertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'roberta' in decoder_pretrained_model_name_or_path:
|
||||||
|
decoder = RobertaModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'bert' in decoder_pretrained_model_name_or_path:
|
||||||
|
decoder = BertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'openai-gpt' in decoder_pretrained_model_name_or_path:
|
||||||
|
decoder = OpenAIGPTModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'gpt2' in decoder_pretrained_model_name_or_path:
|
||||||
|
decoder = GPT2Model.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'transfo-xl' in decoder_pretrained_model_name_or_path:
|
||||||
|
decoder = TransfoXLModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'xlnet' in decoder_pretrained_model_name_or_path:
|
||||||
|
decoder = XLNetModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'xlm' in decoder_pretrained_model_name_or_path:
|
||||||
|
decoder = XLMModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
|
"'xlm', 'roberta'".format(decoder_pretrained_model_name_or_path))
|
||||||
|
|
||||||
|
model = cls(encoder, decoder)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def forward(self, *inputs, *kwargs):
|
||||||
|
# Extract decoder inputs
|
||||||
|
decoder_kwargs = {}
|
||||||
|
for key in kwargs.keys():
|
||||||
|
if key.startswith('decoder_'):
|
||||||
|
decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
|
||||||
|
|
||||||
|
# Compute encoder hidden states if needed
|
||||||
|
encoder_hidden_states = kwargs.pop('encoder_hidden_states', None)
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_outputs = self.encoder(*inputs, *kwargs)
|
||||||
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states
|
||||||
|
decoder_outputs = self.decoder(**decoder_kwargs)
|
||||||
|
|
||||||
|
return decoder_outputs
|
||||||
|
|
||||||
|
|
||||||
|
class Model2Model(PreTrainedSeq2seq):
|
||||||
|
def tie_weights():
|
||||||
|
# We should tie encoder and decoder embeddings if possible here
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Model2LSTM(PreTrainedSeq2seq):
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
if kwargs.get('decoder_model', None) is None:
|
||||||
|
# We will create a randomly initilized LSTM model as decoder
|
||||||
|
if 'decoder_config' not in kwargs:
|
||||||
|
raise ValueError("To load an LSTM in Seq2seq model, please supply either: "
|
||||||
|
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or "
|
||||||
|
" - a dictionary of configuration parameters that will be used to initialize a
|
||||||
|
" torch.nn.LSTM model as `decoder_config` keyword argument. "
|
||||||
|
" E.g. `decoder_config=\{'input_size': 768, 'hidden_size': 768, 'num_layers': 2\}`")
|
||||||
|
kwargs['decoder_model'] = torch.nn.LSTM(kwarg.pop('decoder_config'))
|
||||||
|
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
Reference in New Issue
Block a user