Files
HuggingFace_transformer/src/transformers/modeling_bart.py
2020-04-02 15:18:33 +02:00

1035 lines
43 KiB
Python

# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import logging
import random
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from .activations import ACT2FN
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
logger = logging.getLogger(__name__)
BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
}
BART_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matters related to general usage and behavior.
Parameters:
config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
BART_GENERATION_EXAMPLE = r"""
Examples::
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
"""
BART_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
Padding will be ignored by default should you provide it.
Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices in input_ids.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
See diagram 1 in the paper for more info on the default strategy
"""
def invert_mask(attention_mask):
assert attention_mask.dim() == 2
return attention_mask.eq(0)
def _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
):
"""Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
Note: this is not called during generation
"""
pad_token_id = config.pad_token_id
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
bsz, tgt_len = decoder_input_ids.size()
if decoder_padding_mask is None:
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else:
decoder_padding_mask = invert_mask(decoder_padding_mask)
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
)
return decoder_input_ids, decoder_padding_mask, causal_mask
class PretrainedBartModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
}
return dummy_inputs
def _make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer
# Helper Functions, mostly for making masks
def _check_shapes(shape_1, shape2):
if shape_1 != shape2:
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
def shift_tokens_right(input_ids, pad_token_id):
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
prev_output_tokens = input_ids.clone()
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = input_ids[:, :-1]
return prev_output_tokens
def make_padding_mask(input_ids, padding_idx=1):
"""True for pad tokens"""
padding_mask = input_ids.eq(padding_idx)
if not padding_mask.any():
padding_mask = None
return padding_mask
# Helper Modules
class EncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.output_attentions = config.output_attentions
self.self_attn = SelfAttention(
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def forward(self, x, encoder_padding_mask):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
for t_tgt, t_src is excluded (or masked out), =0 means it is
included in attention
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
x, attn_weights = self.self_attn(
query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.final_layer_norm(x)
return x, attn_weights
class BartEncoder(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer
is a :class:`EncoderLayer`.
Args:
config: BartConfig
"""
def __init__(self, config: BartConfig, embed_tokens):
super().__init__()
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = config.max_position_embeddings
self.embed_tokens = embed_tokens
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim)
def forward(
self, input_ids, attention_mask=None,
):
"""
Args:
input_ids (LongTensor): tokens in the source language of shape
`(batch, src_len)`
attention_mask (torch.LongTensor): indicating which indices are padding tokens.
Returns:
Tuple comprised of:
- **x** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *self.output_hidden_states:* is True.
- **all_attentions** (List[Tensor]): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
if attention_mask is not None:
attention_mask = invert_mask(attention_mask)
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_ids)
x = inputs_embeds + embed_pos
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_states, all_attentions = [], []
for encoder_layer in self.layers:
if self.output_hidden_states:
encoder_states.append(x)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
attn = None
else:
x, attn = encoder_layer(x, attention_mask)
if self.output_attentions:
all_attentions.append(attn)
if self.output_hidden_states:
encoder_states.append(x)
# T x B x C -> B x T x C
encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
x = x.transpose(0, 1)
return x, encoder_states, all_attentions
class DecoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.output_attentions = config.output_attentions
self.self_attn = SelfAttention(
embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.encoder_attn = SelfAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
encoder_decoder_attention=True,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def forward(
self,
x,
encoder_hidden_states,
encoder_attn_mask=None,
layer_state=None,
causal_mask=None,
decoder_padding_mask=None,
):
residual = x
if layer_state is None:
layer_state = {}
# next line mutates layer state
x, self_attn_weights = self.self_attn(
query=x,
key=x,
layer_state=layer_state,
key_padding_mask=decoder_padding_mask,
attn_mask=causal_mask,
need_weights=self.output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
x, _ = self.encoder_attn(
query=x,
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.encoder_attn_layer_norm(x)
residual = x
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.final_layer_norm(x)
return (
x,
self_attn_weights,
layer_state,
) # just self_attn weights for now, following t5, layer_state = cache for decoding
class BartDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer
is a :class:`DecoderLayer`.
Args:
config: BartConfig
embed_tokens (torch.nn.Embedding): output embedding
"""
def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
super().__init__()
self.output_past = config.output_past
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = config.max_position_embeddings
self.embed_tokens = embed_tokens
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, config.d_model, self.padding_idx,
)
self.layers = nn.ModuleList(
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model)
def forward(
self,
input_ids,
encoder_hidden_states,
encoder_padding_mask,
decoder_padding_mask,
decoder_causal_mask,
decoder_cached_states=None,
generation_mode=False,
**unused
):
"""
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
input_ids (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_hidden_states: output from the encoder, used for
encoder-side attention
encoder_padding_mask: for ignoring pad tokens
decoder_cached_states (dict or None): dictionary used for storing state during generation
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- hidden states
- attentions
"""
# check attention mask and invert
if encoder_padding_mask is not None:
encoder_padding_mask = invert_mask(encoder_padding_mask)
# embed positions
positions = self.embed_positions(input_ids, generation_mode=generation_mode)
if generation_mode:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
x = self.embed_tokens(input_ids)
x += positions
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
# decoder layers
all_hidden_states = ()
all_self_attns = ()
next_decoder_cache = []
for i, decoder_layer in enumerate(self.layers):
decoder_layer # type: DecoderLayer
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
x, layer_self_attn, layer_past = decoder_layer(
x,
encoder_hidden_states,
encoder_attn_mask=encoder_padding_mask,
decoder_padding_mask=decoder_padding_mask,
layer_state=layer_state,
causal_mask=decoder_causal_mask,
)
if self.output_past:
next_decoder_cache.append(layer_past.copy())
if self.output_hidden_states:
all_hidden_states += (x,)
if self.output_attentions:
all_self_attns += (layer_self_attn,)
# Convert to standart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
if self.output_past:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else:
next_cache = None
return x, next_cache, all_hidden_states, list(all_self_attns)
def _reorder_buffer(attn_cache, new_order):
for k, input_buffer_k in attn_cache.items():
if input_buffer_k is not None:
attn_cache[k] = input_buffer_k.index_select(0, new_order)
return attn_cache
class SelfAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
encoder_decoder_attention=False, # otherwise self_attention
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.encoder_decoder_attention = encoder_decoder_attention
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
def _shape(self, tensor, dim_0, bsz):
return tensor.contiguous().view(dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1)
def forward(
self,
query,
key: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
attn_mask: Optional[Tensor] = None,
need_weights=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
static_kv = self.encoder_decoder_attention # type: bool
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
# get here for encoder decoder cause of static_kv
if layer_state is not None: # reuse k,v and encoder_padding_mask
saved_state = layer_state.get(self.cache_key, {})
if "prev_key" in saved_state:
# previous time steps are cached - no need to recompute key and value if they are static
if static_kv:
key = None
else:
saved_state = None
layer_state = {}
q = self.q_proj(query) * self.scaling
if static_kv:
if key is None:
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
k = self.k_proj(query)
v = self.v_proj(query)
q = self._shape(q, tgt_len, bsz)
if k is not None:
k = self._shape(k, -1, bsz)
if v is not None:
v = self._shape(v, -1, bsz)
if saved_state is not None:
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
# Update cache
layer_state[self.cache_key] = {
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
}
assert k is not None
src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
if attn_mask is not None:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
assert v is not None
attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = self.out_proj(attn_output)
if need_weights:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
else:
attn_weights = None
return attn_output, attn_weights
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
assert k is not None and v is not None
prev_key_padding_mask = saved_state.get("prev_key_padding_mask", None) # type: Optional[Tensor]
key_padding_mask = self._cat_prev_key_padding_mask(
key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
)
return k, v, key_padding_mask
@staticmethod
def _cat_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None:
if static_kv:
new_key_padding_mask = prev_key_padding_mask
else:
new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
elif key_padding_mask is not None:
filler = torch.zeros(
batch_size,
src_len - key_padding_mask.size(1),
dtype=key_padding_mask.dtype,
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
class BartClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
# This can trivially be shared with RobertaClassificationHead
def __init__(
self, input_dim, inner_dim, num_classes, pooler_dropout,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, x):
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class LearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
Padding ids are ignored by either offsetting based on padding_idx
or by setting padding_idx to None and ensuring that the appropriate
position ids are passed to the forward function.
"""
def __init__(
self, num_embeddings: int, embedding_dim: int, padding_idx: int,
):
# if padding_idx is specified then offset the embedding ids by
# this index and adjust num_embeddings appropriately
assert padding_idx is not None
num_embeddings += padding_idx + 1 # WHY?
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input, generation_mode=False):
"""Input is expected to be of size [bsz x seqlen]."""
if generation_mode: # the position is our current step in the decoded sequence
pos = int(self.padding_idx + input.size(1))
positions = input.data.new(1, 1).fill_(pos)
else:
positions = create_position_ids_from_input_ids(input, self.padding_idx)
return super().forward(positions)
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
if torch.cuda.is_available():
try:
from apex.normalization import FusedLayerNorm
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a input_ids with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
def _filter_out_falsey_values(tup) -> Tuple:
"""Remove entries that are None or [] from an iterable."""
return tuple(x for x in tup if isinstance(x, torch.Tensor) or x)
# Public API
def _get_shape(t):
return getattr(t, "shape", None)
@add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
)
class BartModel(PretrainedBartModel):
def __init__(self, config: BartConfig):
super().__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
self.encoder = BartEncoder(config, self.shared)
self.decoder = BartDecoder(config, self.shared)
self.init_weights()
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
def forward(
self,
input_ids,
attention_mask=None,
decoder_input_ids=None,
encoder_outputs=None, # type: Tuple
decoder_attention_mask=None,
decoder_cached_states=None,
generation_mode=False,
):
# make masks if user doesn't supply
if not generation_mode:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
self.config,
input_ids,
decoder_input_ids=decoder_input_ids,
decoder_padding_mask=decoder_attention_mask,
causal_mask_dtype=self.shared.weight.dtype,
)
else:
decoder_padding_mask, causal_mask = None, None
assert decoder_input_ids is not None
if encoder_outputs is None:
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
assert isinstance(encoder_outputs, tuple)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids,
encoder_outputs[0],
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
)
# Attention and hidden_states will be [] or None if they aren't needed
decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple
assert isinstance(decoder_outputs[0], torch.Tensor)
encoder_outputs = _filter_out_falsey_values(encoder_outputs) # type: tuple
return decoder_outputs + encoder_outputs
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def get_output_embeddings(self):
return _make_linear_from_emb(self.shared) # make it on the fly
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.",
BART_START_DOCSTRING + BART_GENERATION_EXAMPLE,
)
class BartForConditionalGeneration(PretrainedBartModel):
base_model_prefix = "model"
def __init__(self, config: BartConfig):
super().__init__(config)
base_model = BartModel(config)
self.model = base_model
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
def forward(
self,
input_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_cached_states=None,
lm_labels=None,
generation_mode=False,
**unused
):
r"""
masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss.
Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
with labels
in ``[0, ..., config.vocab_size]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
# Mask filling only works for bart-large
from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('bart-large')
TXT = "My friends are <mask> but they eat too many carbs."
model = BartForConditionalGeneration.from_pretrained('bart-large')
input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt')['input_ids']
logits = model(input_ids)[0]
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)
tokenizer.decode(predictions).split()
# ['good', 'great', 'all', 'really', 'very']
"""
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
)
lm_logits = F.linear(outputs[0], self.model.shared.weight)
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
if lm_labels is not None:
loss_fct = nn.CrossEntropyLoss()
# TODO(SS): do we need to ignore pad tokens in lm_labels?
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), lm_labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# first step, decoder_cached_states are empty
if not past[1]:
encoder_outputs, decoder_cached_states = past, None
else:
encoder_outputs, decoder_cached_states = past
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"generation_mode": True,
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):
if cur_len == 1:
self._force_token_ids_generation(scores, self.config.bos_token_id)
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(scores, self.config.eos_token_id)
return scores
@staticmethod
def _reorder_cache(past, beam_idx):
((enc_out, enc_mask), decoder_cached_states) = past
reordered_past = []
for layer_past in decoder_cached_states:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
reordered_past.append(layer_past_new)
new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
past = ((new_enc_out, new_enc_mask), reordered_past)
return past
def get_encoder(self):
return self.model.encoder
def get_output_embeddings(self):
return _make_linear_from_emb(self.model.shared) # make it on the fly
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
BART_START_DOCSTRING,
)
class BartForSequenceClassification(PretrainedBartModel):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BartModel(config)
self.classification_head = BartClassificationHead(
config.d_model, config.d_model, config.num_labels, config.classif_dropout,
)
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
def forward(
self,
input_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification loss (cross entropy)
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the
self-attention
heads.
Examples::
from transformers import BartTokenizer, BartForSequenceClassification
import torch
tokenizer = BartTokenizer.from_pretrained('bart-large')
model = BartForSequenceClassification.from_pretrained('bart-large')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute",
add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
)
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
logits = self.classification_head(sentence_representation)
# Prepend logits
outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here
if labels is not None: # prepend loss to output,
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs