1329 lines
58 KiB
Python
1329 lines
58 KiB
Python
# coding=utf-8
|
|
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch BART model, ported from the fairseq repo."""
|
|
import logging
|
|
import math
|
|
import random
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor, nn
|
|
|
|
from .configuration_bart import BartConfig
|
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
|
from .modeling_utils import BeamHypotheses, PreTrainedModel, create_position_ids_from_input_ids
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
|
|
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
|
|
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
|
|
}
|
|
|
|
BART_START_DOCSTRING = r"""
|
|
|
|
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
|
|
refer to the PyTorch documentation for all matters related to general usage and behavior.
|
|
|
|
Parameters:
|
|
config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
|
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
|
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
|
|
|
"""
|
|
|
|
BART_INPUTS_DOCSTRING = r"""
|
|
Args:
|
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
|
Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
|
|
Padding will be ignored by default should you provide it.
|
|
Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
|
|
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
|
Mask to avoid performing attention on padding token indices in input_ids.
|
|
Mask values selected in ``[0, 1]``:
|
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
|
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
|
Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
|
|
decoder_attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, 1, tgt_seq_len, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
|
Default behavior: generate a tensor that ignores pad tokens and future tokens, as in the paper.
|
|
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_bart_decoder_inputs` and modify it.
|
|
See diagram 1 in the paper for more info on the default strategy
|
|
"""
|
|
# Lower clamp bound used by ``_combine_masks``: ``-inf`` entries of the additive mask become this finite value.
LARGE_NEGATIVE = -1e4
|
|
|
|
|
|
def _prepare_bart_decoder_inputs(
    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None,
):
    """Build default decoder inputs and the additive decoder self-attention mask.

    Mimics the fairseq defaults; pass ``decoder_input_ids`` and/or
    ``decoder_attn_mask`` explicitly to override them.

    Args:
        config: object exposing ``pad_token_id`` and ``output_past``.
        input_ids: encoder token ids. Used to derive ``decoder_input_ids`` (through
            ``shift_tokens_right``) when none are given, and to choose the device of a
            freshly built mask.
        decoder_input_ids: optional decoder token ids; returned unchanged when given.
        decoder_attn_mask: optional pre-built mask; returned unchanged when given.

    Returns:
        ``(decoder_input_ids, decoder_attn_mask)``; the mask is asserted to have shape
        ``(bsz, 1, tgt_len, tgt_len)``.
    """
    pad_id = config.pad_token_id
    # No causal mask is built when ``config.output_past`` is set.
    wants_causal_mask = not config.output_past

    if decoder_input_ids is None:
        decoder_input_ids = shift_tokens_right(input_ids, pad_id)
    bsz, tgt_len = decoder_input_ids.size()[:2]

    if decoder_attn_mask is None:
        pad_mask = make_padding_mask(decoder_input_ids, pad_id)
        causal_mask = None
        if wants_causal_mask:
            # -inf strictly above the diagonal, 0 on and below it
            causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1)
        # the combined mask is broadcastable so it can simply be added to the attention scores
        merged = _combine_masks(pad_mask, causal_mask, (bsz, tgt_len, tgt_len))
        decoder_attn_mask = merged.to(device=input_ids.device)

    assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
    return decoder_input_ids, decoder_attn_mask
|
|
|
|
|
|
class PretrainedBartModel(PreTrainedModel):
    """Common base for the BART models: config class, checkpoint map, weight init and dummy inputs."""

    config_class = BartConfig
    base_model_prefix = "model"
    pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` / ``nn.Embedding`` weights from N(0, ``config.init_std``).

        Linear biases and the embedding row at ``padding_idx`` are zeroed.
        Other module types are left untouched.
        """
        std = self.config.init_std

        # called init_bert_params in fairseq
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        """Return a small two-sentence batch (second row padded) as a dict of forward kwargs."""
        pad_token = 1
        input_ids = torch.Tensor(
            [
                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 2, pad_token],
            ]
        ).long()
        # _prepare_bart_decoder_inputs takes no ``attention_mask`` argument; passing one
        # raised a TypeError, so only the supported keywords are forwarded.
        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
            self.config, input_ids, decoder_input_ids=None, decoder_attn_mask=None
        )
        dummy_inputs = {
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
            "decoder_attention_mask": decoder_attn_mask,
        }
        return dummy_inputs
|
|
|
|
|
|
def _make_linear_from_emb(emb):
    """Create a bias-free ``nn.Linear`` whose weight data is the embedding's weight data.

    The layer maps ``emb_size`` features to ``vocab_size`` outputs.

    Args:
        emb: an ``nn.Embedding`` with weight of shape ``(vocab_size, emb_size)``.

    Returns:
        ``nn.Linear`` with ``in_features == emb_size`` and ``out_features == vocab_size``.
    """
    vocab_size, emb_size = emb.weight.shape
    # nn.Linear(in_features, out_features) stores its weight as (out_features, in_features),
    # so (emb_size -> vocab_size) is the layout that matches the embedding matrix. The
    # previous (vocab_size, emb_size) argument order allocated a transposed weight that was
    # immediately replaced and left in_features/out_features misreported.
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.weight.data  # same storage as the embedding, no copy
    return lin_layer
|
|
|
|
|
|
# Helper Functions, mostly for making masks
|
|
def _check_shapes(shape_1, shape2):
    """Raise ``AssertionError`` unless the two shapes compare equal."""
    if shape_1 == shape2:
        return
    raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
|
|
|
|
|
|
def _combine_masks(key_padding_mask, attn_mask, targ_size):
    """Merge a padding mask and an attention mask into one additive mask.

    Args:
        key_padding_mask: optional boolean tensor whose shape must equal ``targ_size[:2]``.
        attn_mask: optional tensor whose shape must equal ``targ_size[-2:]``.
        targ_size: ``(bsz, tgt_len, src_len)``.

    Returns:
        Tensor of shape ``(bsz, 1, tgt_len, src_len)``, clamped from below at
        ``LARGE_NEGATIVE``.
    """
    padding_term = torch.zeros(targ_size)
    attn_term = torch.zeros(targ_size)

    if key_padding_mask is not None:  # (bsz, tgt_len) -> targ_size
        _check_shapes(key_padding_mask.shape, targ_size[:2])
        expanded_padding = key_padding_mask.unsqueeze(2).expand(*targ_size)
        # NOTE(review): padded positions receive +1e-8, which is numerically negligible;
        # possibly a large negative value was intended -- confirm before changing.
        padding_term[expanded_padding] = 1e-8

    if attn_mask is not None:  # (tgt_len, src_len) -> targ_size
        _check_shapes(attn_mask.shape, targ_size[-2:])
        attn_term = attn_mask.unsqueeze(0).expand(*targ_size)

    combined = padding_term + attn_term
    return combined.unsqueeze(1).clamp(LARGE_NEGATIVE,)
|
|
|
|
|
|
def shift_tokens_right(input_ids, pad_token_id):
    """Return a copy of ``input_ids`` shifted one position to the right.

    The last non-pad token of each row (usually <eos>) is wrapped around to
    position 0. ``input_ids`` itself is not modified.
    """
    non_pad_counts = input_ids.ne(pad_token_id).sum(dim=1)
    last_token_index = (non_pad_counts - 1).unsqueeze(-1)
    shifted = input_ids.clone()
    shifted[:, 0] = input_ids.gather(1, last_token_index).squeeze()
    shifted[:, 1:] = input_ids[:, :-1]
    return shifted
|
|
|
|
|
|
def make_padding_mask(input_ids, padding_idx=1):
    """Return a mask that is True at pad positions, or ``None`` if there are no pads."""
    is_pad = input_ids.eq(padding_idx)
    return is_pad if is_pad.any() else None
|
|
|
|
|
|
# Helper Modules
|
|
|
|
|
|
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then a two-layer feed-forward network.

    Each sub-block is followed by dropout, a residual connection and layer norm.
    """

    def __init__(self, config: BartConfig):
        super().__init__()
        # plain hyper-parameters
        self.embed_dim = config.d_model
        self.output_attentions = config.output_attentions
        self.dropout = config.dropout
        self.activation_dropout = config.activation_dropout
        self.activation_fn = F.gelu
        # sub-modules (creation order kept: attention, its norm, fc1, fc2, final norm)
        self.self_attn = SelfAttention(
            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def forward(self, x, encoder_padding_mask):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask: forwarded to the self-attention as ``key_padding_mask``;
                of shape `(batch, src_len)`, non-zero entries mark positions to ignore.

        Returns:
            tuple of the encoded output of shape `(seq_len, batch, embed_dim)` and the
            attention weights returned by the self-attention module.
        """
        # --- self-attention sub-block ---
        attended, attn_weights = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions,
        )
        attended = F.dropout(attended, p=self.dropout, training=self.training)
        x = self.self_attn_layer_norm(x + attended)

        # --- feed-forward sub-block ---
        hidden = self.activation_fn(self.fc1(x))
        hidden = F.dropout(hidden, p=self.activation_dropout, training=self.training)
        hidden = F.dropout(self.fc2(hidden), p=self.dropout, training=self.training)
        x = self.final_layer_norm(x + hidden)
        return x, attn_weights
|
|
|
|
|
|
class BartEncoder(nn.Module):
    """
    Transformer encoder made of *config.encoder_layers* :class:`EncoderLayer` blocks,
    preceded by token + learned positional embeddings and an embedding layer norm.

    Args:
        config: BartConfig
        embed_tokens: token embedding module (its ``padding_idx`` and ``embedding_dim`` are reused)
    """

    def __init__(self, config: BartConfig, embed_tokens):
        super().__init__()

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.max_source_positions = config.max_position_embeddings

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx

        self.embed_tokens = embed_tokens
        self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = LayerNorm(embed_dim)

    def forward(
        self, input_ids=None, attention_mask=None,
    ):
        """
        Args:
            input_ids (LongTensor): source tokens of shape `(batch, src_len)`
            attention_mask (torch.LongTensor): 2-D mask, 1 for tokens to attend to and
                0 for padding. It is inverted here into 0 / -10000 before being handed
                to the layers.

        Returns:
            tuple:
                - **x** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_states** (List[Tensor]): hidden states before each layer and
                  after the last one, each of shape `(batch, src_len, embed_dim)`.
                  Empty unless ``output_hidden_states`` is set.
                - **all_attentions** (List[Tensor]): attention weights for each layer
                  (``None`` for layers skipped by LayerDrop). Empty unless
                  ``output_attentions`` is set.
        """
        # check attention mask and invert
        if attention_mask is not None:
            assert attention_mask.dim() == 2
            attention_mask = (1.0 - attention_mask.long()) * -10000.0
            assert attention_mask.max() <= 0

        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(input_ids)
        x = self.layernorm_embedding(token_embeds + position_embeds)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = []
        all_attentions = []
        for encoder_layer in self.layers:
            if self.output_hidden_states:
                encoder_states.append(x)

            # LayerDrop (see https://arxiv.org/abs/1909.11556): the random draw happens
            # for every layer, but a layer is only skipped while training.
            dropout_probability = random.uniform(0, 1)
            skip_layer = self.training and (dropout_probability < self.layerdrop)
            if skip_layer:
                attn = None
            else:
                x, attn = encoder_layer(x, attention_mask)

            if self.output_attentions:
                all_attentions.append(attn)

        if self.output_hidden_states:
            encoder_states.append(x)

        # T x B x C -> B x T x C for the collected hidden states only
        encoder_states = [state.transpose(0, 1) for state in encoder_states]

        return x, encoder_states, all_attentions
|
|
|
|
|
|
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder-decoder attention, then a feed-forward network.

    Each sub-block is followed by dropout, a residual connection and layer norm.
    """

    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = SelfAttention(
            embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.activation_fn = F.gelu
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.encoder_attn = SelfAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def forward(
        self,
        x,
        encoder_hidden_states,
        encoder_attn_mask=None,
        layer_state=None,
        attention_mask=None,
        need_attn_weights=False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_hidden_states: keys/values for the encoder-decoder attention
            encoder_attn_mask (optional): mask of shape `(batch, src_len)` forwarded to
                the encoder-decoder attention as ``key_padding_mask``.
            layer_state (dict, optional): per-layer cache; a fresh dict is used when
                ``None``. Both attention modules write into it.
            attention_mask (optional): additive mask for the decoder self-attention.
            need_attn_weights (bool, optional): forwarded to the self-attention as
                ``need_weights``.

        Returns:
            tuple ``(x, self_attn_weights, layer_state)`` where ``x`` has shape
            `(seq_len, batch, embed_dim)`.
        """
        if layer_state is None:
            layer_state = {}

        # --- decoder self-attention (mutates layer_state) ---
        attended, self_attn_weights = self.self_attn(
            query=x, key=x, value=x, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
        )
        attended = F.dropout(attended, p=self.dropout, training=self.training)
        x = self.self_attn_layer_norm(x + attended)

        # the two attention modules must not share a cache slot in layer_state
        assert self.encoder_attn.cache_key != self.self_attn.cache_key

        # --- encoder-decoder attention (mutates layer_state); its weights are discarded ---
        cross_attended, _ = self.encoder_attn(
            query=x,
            key=encoder_hidden_states,  # could be None
            value=encoder_hidden_states,
            key_padding_mask=encoder_attn_mask,
            layer_state=layer_state,
            static_kv=True,
            need_weights=False,
        )
        cross_attended = F.dropout(cross_attended, p=self.dropout, training=self.training)
        x = self.encoder_attn_layer_norm(x + cross_attended)

        # --- feed-forward ---
        hidden = self.activation_fn(self.fc1(x))
        hidden = F.dropout(hidden, p=self.activation_dropout, training=self.training)
        hidden = F.dropout(self.fc2(hidden), p=self.dropout, training=self.training)
        x = self.final_layer_norm(x + hidden)

        # only the self-attention weights are returned; layer_state is the decoding cache
        return (x, self_attn_weights, layer_state)
|
|
|
|
|
|
class BartDecoder(nn.Module):
    """
    Transformer decoder made of *config.decoder_layers* :class:`DecoderLayer` blocks.

    Args:
        config: BartConfig
        embed_tokens (torch.nn.Embedding): token embedding module
    """

    def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
        super().__init__()
        self.output_past = config.output_past
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = config.max_position_embeddings
        self.embed_tokens = embed_tokens
        self.embed_positions = LearnedPositionalEmbedding(
            config.max_position_embeddings, config.d_model, self.padding_idx,
        )
        self.layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.decoder_layers)]
        )  # type: List[DecoderLayer]
        self.layernorm_embedding = LayerNorm(config.d_model)
        # when True, forward() only embeds and decodes the last token of input_ids
        self.generation_mode = False

    def forward(
        self,
        input_ids,
        encoder_hidden_states,
        encoder_padding_mask,
        combined_mask,
        decoder_cached_states=None,
        **unused
    ):
        """
        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            input_ids (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_hidden_states: output from the encoder, used for
                encoder-side attention
            encoder_padding_mask: 2-D mask (1 = keep, 0 = pad) for the encoder tokens;
                inverted here into 0 / -10000
            combined_mask: additive mask passed to every layer's self-attention
            decoder_cached_states (sequence or None): per-layer state dicts, indexed by
                layer number, used during generation

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - ``((encoder_hidden_states, encoder_padding_mask), per-layer caches)``
                  when ``output_past`` is set, else ``None``
                - list of hidden states, each `(batch, tgt_len, embed_dim)`
                - list of self-attention weights
        """
        # check attention mask and invert
        if encoder_padding_mask is not None:
            assert encoder_padding_mask.dim() == 2
            encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
            assert encoder_padding_mask.max() <= 0

        # embed positions
        positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)

        if self.generation_mode:
            # keep only the newest token; slicing happens after the positions were embedded
            input_ids = input_ids[:, -1:]
            positions = positions[:, -1:]
            assert input_ids.ne(self.padding_idx).any()

        x = self.embed_tokens(input_ids)
        x += positions

        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = x.transpose(0, 1)  # (seq_len, BS, model_dim)

        # decoder layers
        hidden_states_per_layer = []
        self_attns_per_layer = []
        next_decoder_cache = []

        for layer_idx, decoder_layer in enumerate(self.layers):
            # LayerDrop (see https://arxiv.org/abs/1909.11556): the random draw happens
            # for every layer, but a layer is only skipped while training.
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            if decoder_cached_states is not None:
                layer_state = decoder_cached_states[layer_idx]
            else:
                layer_state = None

            x, layer_self_attn, layer_past = decoder_layer(
                x,
                encoder_hidden_states,
                encoder_padding_mask,
                layer_state=layer_state,
                attention_mask=combined_mask,
                need_attn_weights=self.output_attentions,
            )

            if self.output_past:
                next_decoder_cache.append(layer_past.copy())
            if self.output_hidden_states:
                hidden_states_per_layer.append(x)
            if self.output_attentions:
                self_attns_per_layer.append(layer_self_attn)

        # Convert shapes from (seq_len, BS, model_dim) to (BS, seq_len, model_dim)
        all_hidden_states = [state.transpose(0, 1) for state in hidden_states_per_layer]
        x = x.transpose(0, 1)

        next_cache = None
        if self.output_past:
            next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
        return x, next_cache, all_hidden_states, self_attns_per_layer
|
|
|
|
|
|
def reorder_attn_buffer(input_buffer, new_order):
    """Reorder buffered internal state along dim 0 (for incremental generation).

    Every non-``None`` tensor in ``input_buffer`` is replaced, in place, by
    ``tensor.index_select(0, new_order)``. The same dict is returned.
    """
    for name, cached in input_buffer.items():
        if cached is None:
            continue
        # re-assigning an existing key does not resize the dict, so this is safe mid-iteration
        input_buffer[name] = cached.index_select(0, new_order)
    return input_buffer
|
|
|
|
|
|
class SelfAttention(nn.Module):
    """Multi-headed attention from the paper 'Attention Is All You Need'.

    Works either as self-attention (keys/values projected from ``query``) or, with
    ``encoder_decoder_attention=True``, as encoder-decoder attention (keys/values
    projected from ``key``). Projected keys/values are cached in ``layer_state``
    under ``self.cache_key`` for incremental decoding.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        encoder_decoder_attention=False,  # otherwise self_attention
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # queries are scaled by 1/sqrt(head_dim) before the dot product
        self.scaling = self.head_dim ** -0.5

        self.encoder_decoder_attention = encoder_decoder_attention
        qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim  # True for all BART

        assert self.encoder_decoder_attention or qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )
        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # slot used inside ``layer_state`` for this module's cached keys/values
        self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"

    def _shape(self, tensor, dim_0, bsz):
        """Reshape ``(dim_0, bsz, embed_dim)`` into ``(bsz * num_heads, dim_0, head_dim)``."""
        return tensor.contiguous().view(dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = False,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time(SeqLen) x Batch x Channel

        Args:
            query: tensor of shape `(tgt_len, batch, embed_dim)`.
            key: source of keys AND values for encoder-decoder attention; may be
                ``None`` when they are already cached. Ignored for self-attention,
                where keys/values are projected from ``query``.
            value: only checked to be ``None`` whenever ``key`` is ``None``.
            key_padding_mask (optional): mask of shape `(batch, src_len)`; non-zero
                entries mark keys that must not be attended to.
            layer_state (dict, optional): cache that is read from and written to
                under ``self.cache_key``. When ``None`` a throw-away dict is used.
            need_weights (bool, optional): accepted for API compatibility; the
                weights are always returned.
            static_kv (bool, optional): reuse cached keys/values as-is instead of
                appending to them (encoder-decoder attention only).
            attn_mask (optional): additive mask broadcastable to
                `(batch, num_heads, tgt_len, src_len)`, typically the causal mask.

        Returns:
            ``(attn_output, attn_weights)`` with shapes `(tgt_len, batch, embed_dim)`
            and `(batch, num_heads, tgt_len, src_len)`.
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        # get here for encoder decoder cause of static_kv
        if layer_state is not None:  # get the last k,v and mask for reuse
            saved_state = layer_state.get(self.cache_key, {})
            if "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention
                    key = value = None
        else:
            saved_state = None
            layer_state = {}

        q = self.q_proj(query) * self.scaling
        if self.encoder_decoder_attention:
            if key is None:
                assert value is None
                k = v = None
            else:
                # NOTE(review): values are projected from ``key``, not ``value``.
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            k = self.k_proj(query)
            v = self.v_proj(query)

        q = self._shape(q, tgt_len, bsz)
        if k is not None:
            k = self._shape(k, -1, bsz)
        if v is not None:
            v = self._shape(v, -1, bsz)

        if saved_state is not None:
            k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)

        # Update cache
        layer_state[self.cache_key] = {
            "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
            "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
            "prev_key_padding_mask": key_padding_mask if not static_kv else None,
        }

        assert k is not None
        src_len = k.size(1)
        attn_weights = torch.bmm(q, k.transpose(1, 2))

        assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None
        assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)

        if key_padding_mask is not None:  # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            # any non-zero mask entry becomes True and is filled with -inf
            reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
            attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)

        assert v is not None
        attn_output = torch.bmm(attn_probs, v)
        assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)
        # the returned weights are the pre-dropout softmax output
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        return attn_output, attn_weights

    def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
        """Merge freshly projected ``k``/``v`` and the padding mask with the cached ones.

        With ``static_kv`` the cached tensors replace ``k``/``v``; otherwise the new
        ones are appended along the sequence dimension.
        """
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        assert k is not None and v is not None
        prev_key_padding_mask = saved_state.get("prev_key_padding_mask", None)  # type: Optional[Tensor]
        key_padding_mask = self._cat_prev_key_padding_mask(
            key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
        )
        return k, v, key_padding_mask

    @staticmethod
    def _cat_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        """Combine the cached and the current key padding masks into one of width ``src_len``.

        Returns the cached mask unchanged for ``static_kv``, ``None`` when neither
        mask exists, and otherwise a float tensor in which the missing side is
        filled with zeros.
        """
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current is None
        elif prev_key_padding_mask is not None:
            # filler is created on the mask's own device
            filler = torch.zeros(
                batch_size, src_len - prev_key_padding_mask.size(1), device=prev_key_padding_mask.device
            )
            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
        elif key_padding_mask is not None:
            # previously ``filler.cuda()``, which targets the default CUDA device and can
            # differ from the mask's device on multi-GPU setups
            filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1), device=key_padding_mask.device)
            new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask
|
|
|
|
|
|
class BartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Computes dropout -> dense -> tanh -> dropout -> output projection.
    """

    # This can trivially be shared with RobertaClassificationHead

    def __init__(
        self, input_dim, inner_dim, num_classes, pooler_dropout,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, x):
        """Map features with last dim ``input_dim`` to logits with last dim ``num_classes``."""
        pooled = torch.tanh(self.dense(self.dropout(x)))
        return self.out_proj(self.dropout(pooled))
|
|
|
|
|
|
class LearnedPositionalEmbedding(nn.Embedding):
    """
    Learned positional embeddings up to a fixed maximum size.

    Position ids are offset by ``padding_idx``: in generation mode the id for a
    sequence of length ``n`` is ``padding_idx + n``, so the table is allocated with
    ``padding_idx + 1`` extra rows on top of ``num_embeddings``.
    """

    def __init__(
        self, num_embeddings: int, embedding_dim: int, padding_idx: int,
    ):
        assert padding_idx is not None
        # room for the offset position ids described in the class docstring
        table_size = num_embeddings + padding_idx + 1
        super().__init__(table_size, embedding_dim, padding_idx=padding_idx)

    def forward(self, input, generation_mode=False):
        """Input is expected to be of size [bsz x seqlen].

        In generation mode a single position (the current decoding step) is
        embedded and the result has shape ``(1, 1, embedding_dim)``; otherwise the
        position ids come from ``create_position_ids_from_input_ids``.
        """
        if not generation_mode:
            positions = create_position_ids_from_input_ids(input, self.padding_idx)
        else:
            # the position is our current step in the decoded sequence
            current_step = int(self.padding_idx + input.size(1))
            positions = input.new_full((1, 1), current_step)
        return super().forward(positions)
|
|
|
|
|
|
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    """Return apex's ``FusedLayerNorm`` when CUDA and apex are usable, else ``torch.nn.LayerNorm``.

    An ``ImportError`` raised while importing or constructing the fused layer
    triggers the fallback.
    """
    fused = None
    if torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm

            fused = FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            fused = None
    if fused is not None:
        return fused
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
|
|
|
|
|
def fill_with_neg_inf(t):
    """FP16-compatible function that returns a tensor of ``-inf`` with the dtype of ``t``.

    The fill happens on ``t.float()``, so a float32 ``t`` is filled in place.
    """
    as_float = t.float()
    as_float.fill_(float("-inf"))
    return as_float.type_as(t)
|
|
|
|
|
|
def _filter_out_falsey_values(tup) -> Tuple:
    """Return a tuple of the entries that are tensors or truthy (drops ``None``, ``[]``, ...)."""
    kept = []
    for item in tup:
        # tensors are always kept; their truthiness is never evaluated
        if isinstance(item, torch.Tensor) or item:
            kept.append(item)
    return tuple(kept)
|
|
|
|
|
|
# Public API
|
|
|
|
|
|
@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
)
class BartModel(PretrainedBartModel):
    def __init__(self, config: BartConfig):
        super().__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        # one embedding table shared by the encoder and the decoder
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = BartEncoder(config, self.shared)
        self.decoder = BartDecoder(config, self.shared)

        self.init_weights()

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        encoder_outputs=None,  # type: Tuple
        decoder_attention_mask=None,
        decoder_cached_states=None,
    ):
        # Returns the decoder outputs followed by the encoder outputs, as one tuple with
        # empty / None entries removed. (Kept as a comment so the docstring added by the
        # decorator stays untouched.)

        # make masks if user doesn't supply
        if not self.decoder.generation_mode:
            decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
                self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
            )
        assert decoder_input_ids is not None
        # the encoder is skipped when its outputs are passed in
        if encoder_outputs is None:
            encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        assert isinstance(encoder_outputs, tuple)
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids,
            encoder_outputs[0],
            attention_mask,
            decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
        )
        # Attention and hidden_states will be [] or None if they aren't needed
        decoder_outputs = _filter_out_falsey_values(decoder_outputs)  # type: tuple
        assert isinstance(decoder_outputs[0], torch.Tensor)
        encoder_outputs = _filter_out_falsey_values(encoder_outputs)  # type: tuple
        return decoder_outputs + encoder_outputs

    def get_input_embeddings(self):
        """Return the token embedding shared by encoder and decoder."""
        return self.shared

    def set_input_embeddings(self, value):
        """Replace the shared token embedding everywhere it is used.

        The encoder and decoder each hold their own ``embed_tokens`` reference, so
        rebinding ``self.shared`` alone would leave them on the old table.
        """
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_output_embeddings(self):
        """Build, on the fly, a linear layer that shares the input embedding weights."""
        return _make_linear_from_emb(self.shared)  # make it on the fly
|
|
|
|
|
|
@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING,
)
class BartForConditionalGeneration(PretrainedBartModel):
    base_model_prefix = "model"

    def __init__(self, config: BartConfig):
        super().__init__(config)
        self.model = BartModel(config)
        # The LM head is derived from the base model's shared token embedding.
        self.lm_head = _make_linear_from_emb(self.model.shared)

    def tie_weights(self):
        pass  # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_cached_states=None,
        lm_labels=None,
        **unused
    ):
        r"""
        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the language modeling loss.
            Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
            with labels
            in ``[0, ..., config.vocab_size]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
        masked_lm_loss (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

            # Mask filling only works for bart-large
            from transformers import BartTokenizer, BartForConditionalGeneration
            tokenizer = BartTokenizer.from_pretrained('bart-large')
            TXT = "My friends are <mask> but they eat too many carbs."
            model = BartForConditionalGeneration.from_pretrained('bart-large')
            input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt')['input_ids']
            logits = model(input_ids)[0]
            masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
            probs = logits[0, masked_index].softmax(dim=0)
            values, predictions = probs.topk(5)
            tokenizer.decode(predictions).split()
            # ['good', 'great', 'all', 'really', 'very']
        """
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
        )
        lm_logits = self.lm_head(outputs[0])
        outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
        if lm_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # TODO(SS): do we need to ignore pad tokens in lm_labels?
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), lm_labels.view(-1))
            outputs = (masked_lm_loss,) + outputs

        return outputs

    @staticmethod
    def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
        """Build the keyword arguments for one decoding step of :meth:`forward`.

        ``past`` is ``None`` on the first step; afterwards it is the pair
        ``(encoder_outputs, decoder_cached_states)`` taken from the previous step's outputs.
        """
        if past is None:  # first step
            encoder_outputs, decoder_cached_states = None, None
        else:
            encoder_outputs, decoder_cached_states = past
        return {
            "input_ids": input_ids,  # ignored after first pass
            "decoder_cached_states": decoder_cached_states,
            "decoder_input_ids": decoder_input_ids,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Re-index the cached encoder outputs and per-layer decoder caches by ``beam_idx``.

        ``past`` must have the structure ``((enc_out, enc_mask), decoder_cached_states)``; the
        returned value has the same structure.
        """
        ((enc_out, enc_mask), decoder_cached_states) = past
        reordered_past = []
        for layer_past in decoder_cached_states:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
                attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
        # enc_out is re-indexed along dim 1 and enc_mask along dim 0.
        new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
        new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)

        past = ((new_enc_out, new_enc_mask), reordered_past)
        return past

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        max_length=20,
        num_beams=1,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_return_sequences=1,
        min_len=0,
        no_repeat_ngram_size=0,
    ):
        r""" Generates summaries using the lm-head and greedy beam search

        Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.

        .. _`XLM beam search code`:
           https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
        .. _`Fairseq beam search code`:
           https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py


        Parameters:

            input_ids: `torch.LongTensor` of shape `(batch_size, sequence_length)`
                The sequence used as input to the encoder. Required (must not be `None`).

            attention_mask: (`optional`) tensor of shape `(batch_size, sequence_length)`
                Mask forwarded to the model together with `input_ids`. It is expanded alongside `input_ids`
                for `num_return_sequences` and `num_beams`.

            max_length: (`optional`) int
                The max length of the sequence to be generated. Does not include tokens in input_ids.

            num_beams: (`optional`) int
                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.

            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.

            length_penalty: (`optional`) float
                Exponential penalty to the length. Default to 1.

            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Default to 1.

            min_len: (`optional`) int
                EOS cannot be selected at decoding steps ``1 <= step < min_len``. Default to 0.

            no_repeat_ngram_size: (`optional`) int
                If > 0, tokens that would complete an n-gram of this size already present in a hypothesis
                are banned for that hypothesis. Default to 0 (disabled).

        Returns:
            `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is <= max_length (examples can finish early)

        Examples::
            from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
            # see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
            config = BartConfig(vocab_size=50264, output_past=True)  # no mask_token_id
            model = BartForConditionalGeneration.from_pretrained('bart-large-cnn', config=config)
            tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
            ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
            inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
            # Generate Summary
            summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
            print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
        """
        # Validate before touching `input_ids.shape`, otherwise this check could never fire.
        assert input_ids is not None, "`input_ids` must be provided."
        bos_token_id = self.config.bos_token_id
        pad_token_id = self.config.pad_token_id
        eos_token_id = self.config.eos_token_id
        assert self.config.output_past, "Generating with bart requires instantiating a config with output_past=True"
        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert isinstance(pad_token_id, int)
        assert bos_token_id == 0, "configurable bos_token_id not yet supported"
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a positive integer."

        # current position and vocab size
        batch_size, cur_len = input_ids.shape
        vocab_size = self.config.vocab_size

        if num_return_sequences != 1:
            # Expand input to num return sequences
            input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
            input_ids = input_ids.contiguous().view(
                batch_size * num_return_sequences, cur_len
            )  # shape: (batch_size * num_return_sequences, cur_len)
            if attention_mask is not None:
                # Keep the mask aligned with the repeated inputs; without this the
                # beam expansion below fails for any original batch larger than 1.
                attention_mask = (
                    attention_mask.unsqueeze(1)
                    .expand(batch_size, num_return_sequences, cur_len)
                    .contiguous()
                    .view(batch_size * num_return_sequences, cur_len)
                )
            batch_size *= num_return_sequences

        # Below here somewhat similar to PretrainedModel._generate_beam_search
        # Expand input to num beams
        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
        if attention_mask is not None:
            attention_mask = (
                attention_mask.unsqueeze(1)
                .expand(batch_size, num_beams, cur_len)
                .contiguous()
                .view(batch_size * num_beams, cur_len)
            )  # (batch_size * num_beams, cur_len)

        # generated hypotheses
        finalized_hyps = [  # they end in EOS and we wont work on them more!
            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=True) for _ in range(batch_size)
        ]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9  # avoid ties in first step
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # decoder tokens
        prev_output_tokens = input_ids.new(batch_size * num_beams, 1).long().fill_(-1)
        prev_output_tokens[:, 0] = 2  # HARDCODED EOS, which will be removed at the end.
        decoder_cache = None
        done = [False for _ in range(batch_size)]  # done sentences
        num_hypos = batch_size * num_beams

        # While this flag is set, BartModel.forward skips its decoder input/mask preparation.
        # Restore the previous value afterwards so later forward passes are unaffected.
        decoder = self.model.decoder
        previous_generation_mode = decoder.generation_mode
        decoder.generation_mode = True
        try:
            for step in range(max_length + 1):
                decoder_input_ids = prev_output_tokens.clone()
                model_inputs = self.prepare_inputs_for_generation(
                    input_ids, decoder_cache, decoder_input_ids, attention_mask,
                )
                outputs = self(**model_inputs)
                lprobs = F.log_softmax(outputs[0][:, -1, :], dim=-1)

                lprobs[lprobs != lprobs] = -math.inf  # block nans
                lprobs[:, pad_token_id] = -math.inf
                # TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3

                if step == 0:  # Force BOS to be chosen
                    lprobs[:, bos_token_id + 1 :] = -math.inf
                elif step < min_len:  # Prevent EOS from being chosen
                    lprobs[:, eos_token_id] = -math.inf
                elif step == max_length:  # FORCE EOS to be chosen
                    lprobs[:, :eos_token_id] = -math.inf
                    lprobs[:, eos_token_id + 1 :] = -math.inf
                assert self._do_output_past(outputs)
                decoder_cache = outputs[1]
                if repetition_penalty != 1.0:
                    self.enforce_repetition_penalty_(
                        lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty
                    )
                if no_repeat_ngram_size > 0:  # copied from fairseq
                    # for each sentence, calculate a list of banned tokens to prevent repetitively generating the same ngrams
                    banned_tokens = self.calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step)
                    # then set their probabilities to -inf
                    for idx in range(num_hypos):
                        lprobs[idx, banned_tokens[idx]] = -math.inf
                assert lprobs.size() == (batch_size * num_beams, vocab_size)
                _scores = lprobs + beam_scores[:, None].expand_as(lprobs)  # (batch_size * num_beams, vocab_size)

                # re-organize to group the beam together (we are keeping top hypothesis across beams)
                _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
                # Take the best 2 x beam_size predictions for each example, we'll choose the first beam_size of these
                # which don't predict eos to continue with.
                next_scores, next_words = torch.topk(_scores, 2 * num_beams)
                assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
                assert prev_output_tokens.shape[1] == (step + 1)

                # list of (batch_size * num_beams)
                next_batch_beam = []  # Tuple(next score, next word, current position in the batch)
                for batch_idx in range(batch_size):
                    # if we are done with this sentence (because we can't improve)
                    if done[batch_idx]:  # then pad all associated hypotheses
                        assert (
                            len(finalized_hyps[batch_idx]) >= num_beams
                        ), "Example can only be done if at least {} beams have been generated".format(num_beams)
                        next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                        continue

                    # Otherwise generate some next word choices
                    next_sent_beam = []
                    # add next words for this sentence
                    for i, (idx, score) in enumerate(zip(next_words[batch_idx], next_scores[batch_idx])):
                        beam_id = idx // vocab_size
                        word_id = idx % vocab_size
                        if word_id.item() == eos_token_id:
                            if i >= num_beams:
                                continue
                            finalized_hyps[batch_idx].add(
                                prev_output_tokens[batch_idx * num_beams + beam_id].clone(), score.item(),
                            )
                        else:
                            next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))

                        if len(next_sent_beam) == num_beams:  # TODO(SS): can we delete this?
                            break
                    # Check if were done so that we can save a pad step if all(done)
                    done[batch_idx] = done[batch_idx] or finalized_hyps[batch_idx].is_done(
                        next_scores[batch_idx].max().item(), cur_len=step + 1,
                    )
                    assert len(next_sent_beam) == num_beams, "Beam should always be full"
                    next_batch_beam.extend(next_sent_beam)
                    assert len(next_batch_beam) == num_beams * (batch_idx + 1)

                if all(done):
                    break

                # sanity check / prepare next batch
                assert len(next_batch_beam) == batch_size * num_beams
                beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
                beam_words = input_ids.new([x[1] for x in next_batch_beam])
                beam_idx = input_ids.new([x[2] for x in next_batch_beam])
                # re-order decoder inputs to [beam_idx]
                prev_output_tokens = prev_output_tokens[beam_idx]
                prev_output_tokens = torch.cat([prev_output_tokens, beam_words.unsqueeze(1)], dim=-1)

                # re-order internal states
                decoder_cache = self._reorder_cache(decoder_cache, beam_idx)
        finally:
            decoder.generation_mode = previous_generation_mode

        for batch_idx in range(batch_size):
            # Add all open beam hypothesis to generated_hyps
            if done[batch_idx]:
                continue
            offset = batch_idx * num_beams
            for i in range(num_beams):
                score = beam_scores[offset + i]
                final_tokens = prev_output_tokens[offset + i]
                finalized_hyps[batch_idx].add(final_tokens, score.item())

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size)
        best = []
        for i, hypotheses in enumerate(finalized_hyps):
            best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
            sent_lengths[i] = len(best_hyp)
            best.append(best_hyp)

        # shorter batches are filled with pad_token
        if sent_lengths.min().item() != sent_lengths.max().item():
            # TODO(SS): decoded = torch.rnn.utils.pad_sequence(best, batch_first=True, padding_value=pad_token_id)
            sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1)  # TODO(SS): same as step?
            decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
            # fill with hypothesis and eos_token_id if necessary
            for i, hypo in enumerate(best):
                decoded[i, : sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
                    decoded[i, sent_lengths[i]] = eos_token_id
        else:
            # All best hypotheses have the same length here, so they can be stacked directly.
            # (A former `assert (<generator>)` on this branch was always truthy and has been dropped.)
            decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
        return decoded[:, 1:]  # get rid of starting EOS

    @staticmethod
    def calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step):
        """Copied from fairseq for no_repeat_ngram in beam_search.

        For each of the ``num_hypos`` rows of ``prev_output_tokens``, return the list of tokens that
        would complete an n-gram (of size ``no_repeat_ngram_size``) already present in that row.
        Returns one (possibly empty) list per hypothesis.
        """
        # TODO(SS): this can go on parent if there is demand
        if step + 2 < no_repeat_ngram_size:
            # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
            return [[] for _ in range(num_hypos)]

        # Per hypothesis: map each (n-1)-token prefix to every token that has followed it so far.
        gen_ngrams = [{} for _ in range(num_hypos)]
        for idx in range(num_hypos):
            gen_tokens = prev_output_tokens[idx].tolist()
            for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
                gen_ngrams[idx].setdefault(tuple(ngram[:-1]), []).append(ngram[-1])

        def _get_generated_ngrams(hypo_idx):
            """Before decoding the next token, prevent decoding of ngrams that have already appeared"""
            ngram_index = tuple(prev_output_tokens[hypo_idx, step + 2 - no_repeat_ngram_size : step + 1].tolist())
            return gen_ngrams[hypo_idx].get(ngram_index, [])

        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
        return banned_tokens
|
|
|
|
|
|
@add_start_docstrings(
    """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
    BART_START_DOCSTRING,
)
class BartForSequenceClassification(PretrainedBartModel):
    # BartModel followed by a BartClassificationHead applied to the hidden state at the last <eos> position.

    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model, config.d_model, config.num_labels, config.classif_dropout,
        )
        # Initialise both head layers with the base model's weight-init routine.
        head = self.classification_head
        for layer in (head.dense, head.out_proj):
            self.model._init_weights(layer)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            When provided, a classification loss is computed (Cross-Entropy).

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
                Classification loss (cross entropy)
            logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
                Classification scores (before SoftMax).
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attentions weights after the attention softmax, used to compute the weighted average in the
                self-attention
                heads.

    Examples::

        from transformers import BartTokenizer, BartForSequenceClassification
        import torch

        tokenizer = BartTokenizer.from_pretrained('bart-large')
        model = BartForSequenceClassification.from_pretrained('bart-large')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute",
        add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

        """
        base_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
        )
        hidden = base_outputs[0]  # last hidden state

        # Every row must contain the same number of <eos> tokens so the gathered
        # states can be reshaped into a regular (batch, n_eos, dim) tensor.
        eos_mask = input_ids.eq(self.config.eos_token_id)
        eos_per_row = eos_mask.sum(1)
        if len(torch.unique(eos_per_row)) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")

        eos_states = hidden[eos_mask, :]
        eos_states = eos_states.view(hidden.size(0), -1, hidden.size(-1))
        sentence_representation = eos_states[:, -1, :]  # state at the last <eos> of each row
        logits = self.classification_head(sentence_representation)

        # Logits replace the hidden state at the front; any remaining base-model outputs follow.
        result = (logits,) + base_outputs[1:]
        if labels is None:
            return result

        loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
        return (loss,) + result
|