# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""PyTorch MMBT model. """
|
|
|
|
|
|
import logging
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import CrossEntropyLoss, MSELoss
|
|
|
|
from .file_utils import add_start_docstrings
|
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModalEmbeddings(nn.Module):
    """Embed a non-text modality so it can be consumed by a text transformer.

    The output of ``encoder`` is linearly projected from
    ``config.modal_hidden_size`` to ``config.hidden_size``. Position and
    token-type embeddings taken from the supplied transformer ``embeddings``
    object are then added, followed by that object's ``LayerNorm`` and a
    dropout with probability ``config.hidden_dropout_prob``.
    """

    def __init__(self, config, encoder, embeddings):
        super(ModalEmbeddings, self).__init__()
        self.config = config
        self.encoder = encoder
        self.proj_embeddings = nn.Linear(config.modal_hidden_size, config.hidden_size)
        # The following sub-modules are the very same objects held by
        # ``embeddings`` (no copy is made).
        self.position_embeddings = embeddings.position_embeddings
        self.token_type_embeddings = embeddings.token_type_embeddings
        self.word_embeddings = embeddings.word_embeddings
        self.LayerNorm = embeddings.LayerNorm
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, input_modal, start_token=None, end_token=None, position_ids=None, token_type_ids=None):
        """Return the embedded modal sequence.

        Args:
            input_modal: Input handed unchanged to ``self.encoder``.
            start_token: Optional token ids whose word embedding is prepended
                as one extra sequence position.
            end_token: Optional token ids whose word embedding is appended as
                one extra sequence position.
            position_ids: Optional position indices; defaults to
                ``0 .. seq_length - 1`` for every batch row.
            token_type_ids: Optional segment indices; defaults to all zeros.

        Returns:
            The sum of token, position and token-type embeddings after
            LayerNorm and dropout.
        """
        n_rows = input_modal.size(0)
        target_device = input_modal.device

        # Collect the pieces in sequence order: [start] + projected + [end].
        pieces = [self.proj_embeddings(self.encoder(input_modal))]
        if start_token is not None:
            pieces.insert(0, self.word_embeddings(start_token).unsqueeze(1))
        if end_token is not None:
            pieces.append(self.word_embeddings(end_token).unsqueeze(1))
        token_embeddings = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
        seq_length = token_embeddings.size(1)

        if position_ids is None:
            base_positions = torch.arange(seq_length, dtype=torch.long, device=target_device)
            position_ids = base_positions.unsqueeze(0).expand(n_rows, seq_length)

        if token_type_ids is None:
            token_type_ids = torch.zeros((n_rows, seq_length), dtype=torch.long, device=target_device)

        summed = (
            token_embeddings
            + self.position_embeddings(position_ids)
            + self.token_type_embeddings(token_type_ids)
        )
        return self.dropout(self.LayerNorm(summed))
|
|
|
|
|
|
# Shared introduction prepended (via ``add_start_docstrings``) to the docstring
# of every MMBT model class in this module.
MMBT_START_DOCSTRING = r"""    MMBT model was proposed in
    `Supervised Multimodal Bitransformers for Classifying Images and Text`_
    by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
    It's a supervised multimodal bitransformer model that fuses information from text and other image encoders,
    and obtains state-of-the-art performance on various multimodal classification benchmark tasks.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`Supervised Multimodal Bitransformers for Classifying Images and Text`:
        https://github.com/facebookresearch/mmbt

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.MMBTConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
        transformer (:class:`~torch.nn.Module`): A text transformer that is used by MMBT.
            It should have embeddings, encoder, and pooler attributes.
        encoder (:class:`~torch.nn.Module`): Encoder for the second modality.
            It should take in a batch of modal inputs and return k, n dimension embeddings.
"""
|
|
|
|
MMBT_INPUTS_DOCSTRING = r""" Inputs:
|
|
**input_modal**: ``torch.FloatTensor`` of shape ``(batch_size, ***)``:
|
|
The other modality data. It will be the shape that the encoder for that type expects.
|
|
e.g. With an Image Encoder, the shape would be (batch_size, channels, height, width)
|
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
|
Indices of input sequence tokens in the vocabulary.
|
|
It does not expect [CLS] token to be added as it's appended to the end of other modality embeddings.
|
|
See :func:`transformers.PreTrainedTokenizer.encode` and
|
|
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
|
**modal_start_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
|
Optional start token to be added to Other Modality Embedding. [CLS] Most commonly used for Classification tasks.
|
|
**modal_end_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
|
Optional end token to be added to Other Modality Embedding. [SEP] Most commonly used.
|
|
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
|
Mask to avoid performing attention on padding token indices.
|
|
Mask values selected in ``[0, 1]``:
|
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
|
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
|
Segment token indices to indicate different portions of the inputs.
|
|
**modal_token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
|
|
Segment token indices to indicate different portions of the non-text modality.
|
|
The embeddings from these tokens will be summed with the respective token embeddings for the non-text modality.
|
|
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
|
Indices of positions of each input sequence tokens in the position embeddings.
|
|
**modal_position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
|
|
Indices of positions of each input sequence tokens in the position embeddings for the non-text modality.
|
|
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
|
Mask to nullify selected heads of the self-attention modules.
|
|
Mask values selected in ``[0, 1]``:
|
|
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
|
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
|
|
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
|
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
|
than the model's internal embedding lookup matrix.
|
|
**encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
|
|
is configured as a decoder.
|
|
**encoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
|
is used in the cross-attention if the model is configured as a decoder.
|
|
Mask values selected in ``[0, 1]``:
|
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
|
"""
|
|
|
|
|
|
@add_start_docstrings(
    "The bare MMBT Model outputting raw hidden-states without any specific head on top.",
    MMBT_START_DOCSTRING,
    MMBT_INPUTS_DOCSTRING,
)
class MMBTModel(nn.Module):
    r"""
        Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
            **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
                Sequence of hidden-states at the output of the last layer of the model.
            **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
                Last layer hidden-state of the first token of the sequence (classification token)
                further processed by a Linear layer and a Tanh activation function. The Linear
                layer weights are trained from the next sentence prediction (classification)
                objective during Bert pretraining. This output is usually *not* a good summary
                of the semantic content of the input, you're often better with averaging or pooling
                the sequence of hidden-states for the whole input sequence.
            **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
                list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
                of shape ``(batch_size, sequence_length, hidden_size)``:
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            **attentions**: (`optional`, returned when ``config.output_attentions=True``)
                list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            transformer = BertModel.from_pretrained('bert-base-uncased')
            encoder = ImageEncoder(args)
            mmbt = MMBTModel(config, transformer, encoder)
    """

    def __init__(self, config, transformer, encoder):
        super(MMBTModel, self).__init__()
        self.config = config
        self.transformer = transformer
        # The modal branch reuses the text transformer's embedding sub-modules.
        self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings)

    def forward(
        self,
        input_modal,
        input_ids=None,
        modal_start_tokens=None,
        modal_end_tokens=None,
        attention_mask=None,
        token_type_ids=None,
        modal_token_type_ids=None,
        position_ids=None,
        modal_position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        """Embed both modalities, concatenate them (modal first) and run the transformer.

        Returns:
            ``(sequence_output, pooled_output) + encoder_outputs[1:]``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given, or if a mask is neither 2D nor 3D.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_txt_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_txt_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        modal_embeddings = self.modal_encoder(
            input_modal,
            start_token=modal_start_tokens,
            end_token=modal_end_tokens,
            position_ids=modal_position_ids,
            token_type_ids=modal_token_type_ids,
        )

        input_modal_shape = modal_embeddings.size()[:-1]

        # Text tokens default to segment 1.
        if token_type_ids is None:
            token_type_ids = torch.ones(input_txt_shape, dtype=torch.long, device=device)

        txt_embeddings = self.transformer.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        # Final sequence layout: [modal positions | text positions].
        embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1)

        input_shape = embedding_output.size()[:-1]

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        else:
            # The caller's mask only covers the text part: prepend ones so every
            # modal position is attended to. The prefix uses the mask's own dtype
            # so concatenation does not depend on implicit dtype promotion.
            modal_mask = torch.ones(input_modal_shape, device=device, dtype=attention_mask.dtype)
            attention_mask = torch.cat([modal_mask, attention_mask], dim=1)

        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(input_shape, device=device)
        else:
            encoder_modal_mask = torch.ones(input_modal_shape, device=device, dtype=encoder_attention_mask.dtype)
            encoder_attention_mask = torch.cat([encoder_modal_mask, encoder_attention_mask], dim=1)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        # Provided a padding mask of dimensions [batch_size, seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
        elif attention_mask.dim() == 2:
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for attention_mask (shape {}): expected a 2D or 3D tensor".format(
                    tuple(attention_mask.shape)
                )
            )

        # Masks are cast to the dtype of the model parameters (fp16 compatibility).
        param_dtype = next(self.parameters()).dtype

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=param_dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        elif encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for encoder_attention_mask (shape {}): expected a 2D or 3D tensor".format(
                    tuple(encoder_attention_mask.shape)
                )
            )

        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=param_dtype)
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # We can specify head_mask for each layer
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            # switch to float if needed + fp16 compatibility
            head_mask = head_mask.to(dtype=param_dtype)
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.transformer.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.transformer.pooler(sequence_output)

        # add hidden_states and attentions if they are here
        outputs = (sequence_output, pooled_output) + encoder_outputs[1:]
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

    def get_input_embeddings(self):
        """Return the word-embedding module of the wrapped text transformer."""
        # The embeddings live on the wrapped transformer; this class never
        # defines ``self.embeddings``.
        return self.transformer.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module used for text and for modal start/end tokens."""
        self.transformer.embeddings.word_embeddings = value
        # ``modal_encoder`` holds its own reference to the word embeddings
        # (used for the modal start/end tokens); keep it pointing at the same module.
        self.modal_encoder.word_embeddings = value
|
|
|
|
|
|
@add_start_docstrings(
    """MMBT Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output)""",
    MMBT_START_DOCSTRING,
    MMBT_INPUTS_DOCSTRING,
)
class MMBTForClassification(nn.Module):
    r"""
            **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
                Labels for computing the sequence classification/regression loss.
                Indices should be in ``[0, ..., config.num_labels - 1]``.
                If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
                If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

        Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
            **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
                Classification (or regression if config.num_labels==1) loss.
            **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
                Classification (or regression if config.num_labels==1) scores (before SoftMax).
            **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
                list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
                of shape ``(batch_size, sequence_length, hidden_size)``:
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            **attentions**: (`optional`, returned when ``config.output_attentions=True``)
                list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            transformer = BertModel.from_pretrained('bert-base-uncased')
            encoder = ImageEncoder(args)
            model = MMBTForClassification(config, transformer, encoder)
            outputs = model(input_modal, input_ids, labels=labels)
            loss, logits = outputs[:2]
    """

    def __init__(self, config, transformer, encoder):
        super(MMBTForClassification, self).__init__()
        self.num_labels = config.num_labels

        self.mmbt = MMBTModel(config, transformer, encoder)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_modal,
        input_ids=None,
        modal_start_tokens=None,
        modal_end_tokens=None,
        attention_mask=None,
        token_type_ids=None,
        modal_token_type_ids=None,
        position_ids=None,
        modal_position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Classify (or regress on) the pooled MMBT output.

        Returns:
            ``(loss,) + (logits,) + extras`` when ``labels`` is given,
            otherwise ``(logits,) + extras``, where ``extras`` are the
            elements of the base model output from index 2 onwards.
        """
        base_outputs = self.mmbt(
            input_modal=input_modal,
            input_ids=input_ids,
            modal_start_tokens=modal_start_tokens,
            modal_end_tokens=modal_end_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            modal_token_type_ids=modal_token_type_ids,
            position_ids=position_ids,
            modal_position_ids=modal_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        # Index 1 of the base output is the pooled representation.
        logits = self.classifier(self.dropout(base_outputs[1]))

        # add hidden states and attention if they are here
        result = (logits,) + base_outputs[2:]
        if labels is None:
            return result  # logits, (hidden_states), (attentions)

        if self.num_labels == 1:
            # A single label means regression: mean-squared error on flat vectors.
            loss = MSELoss()(logits.view(-1), labels.view(-1))
        else:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        return (loss,) + result  # loss, logits, (hidden_states), (attentions)
|