|
|
|
|
@@ -21,6 +21,7 @@ import torch
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
from torch import Tensor, nn
|
|
|
|
|
|
|
|
|
|
from .activations import ACT2FN
|
|
|
|
|
from .configuration_bart import BartConfig
|
|
|
|
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
|
|
|
|
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
|
|
|
|
|
@@ -196,7 +197,7 @@ class EncoderLayer(nn.Module):
|
|
|
|
|
)
|
|
|
|
|
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
|
|
|
|
self.dropout = config.dropout
|
|
|
|
|
self.activation_fn = F.gelu
|
|
|
|
|
self.activation_fn = ACT2FN[config.activation_function]
|
|
|
|
|
self.activation_dropout = config.activation_dropout
|
|
|
|
|
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
|
|
|
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
|
|
|
|
@@ -278,8 +279,8 @@ class BartEncoder(nn.Module):
|
|
|
|
|
# check attention mask and invert
|
|
|
|
|
if attention_mask is not None:
|
|
|
|
|
assert attention_mask.dim() == 2
|
|
|
|
|
attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
|
|
|
|
|
assert attention_mask.max() <= 0
|
|
|
|
|
attention_mask = attention_mask.eq(0)
|
|
|
|
|
|
|
|
|
|
inputs_embeds = self.embed_tokens(input_ids)
|
|
|
|
|
embed_pos = self.embed_positions(input_ids)
|
|
|
|
|
x = inputs_embeds + embed_pos
|
|
|
|
|
@@ -318,7 +319,7 @@ class DecoderLayer(nn.Module):
|
|
|
|
|
embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
|
|
|
|
|
)
|
|
|
|
|
self.dropout = config.dropout
|
|
|
|
|
self.activation_fn = F.gelu
|
|
|
|
|
self.activation_fn = ACT2FN[config.activation_function]
|
|
|
|
|
self.activation_dropout = config.activation_dropout
|
|
|
|
|
|
|
|
|
|
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
|
|
|
|
@@ -334,13 +335,7 @@ class DecoderLayer(nn.Module):
|
|
|
|
|
self.final_layer_norm = LayerNorm(self.embed_dim)
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
x,
|
|
|
|
|
encoder_hidden_states,
|
|
|
|
|
encoder_attn_mask=None,
|
|
|
|
|
layer_state=None,
|
|
|
|
|
attention_mask=None,
|
|
|
|
|
need_attn_weights=False,
|
|
|
|
|
self, x, encoder_hidden_states, encoder_attn_mask=None, layer_state=None, attention_mask=None,
|
|
|
|
|
):
|
|
|
|
|
residual = x
|
|
|
|
|
|
|
|
|
|
@@ -437,9 +432,7 @@ class BartDecoder(nn.Module):
|
|
|
|
|
# check attention mask and invert
|
|
|
|
|
if encoder_padding_mask is not None:
|
|
|
|
|
assert encoder_padding_mask.dim() == 2
|
|
|
|
|
|
|
|
|
|
encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
|
|
|
|
|
assert encoder_padding_mask.max() <= 0
|
|
|
|
|
encoder_padding_mask = encoder_padding_mask.eq(0)
|
|
|
|
|
|
|
|
|
|
# embed positions
|
|
|
|
|
positions = self.embed_positions(input_ids, generation_mode=generation_mode)
|
|
|
|
|
@@ -469,12 +462,7 @@ class BartDecoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
|
|
|
|
|
x, layer_self_attn, layer_past = decoder_layer(
|
|
|
|
|
x,
|
|
|
|
|
encoder_hidden_states,
|
|
|
|
|
encoder_padding_mask,
|
|
|
|
|
layer_state=layer_state,
|
|
|
|
|
attention_mask=combined_mask,
|
|
|
|
|
need_attn_weights=self.output_attentions,
|
|
|
|
|
x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.output_past:
|
|
|
|
|
@@ -598,7 +586,7 @@ class SelfAttention(nn.Module):
|
|
|
|
|
|
|
|
|
|
if key_padding_mask is not None: # don't attend to padding symbols
|
|
|
|
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
|
|
|
|
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
|
|
|
|
|
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
|
|
|
|
|
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
|
|
|
|
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
|
|
|
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
|
|
|
|
@@ -648,22 +636,20 @@ class SelfAttention(nn.Module):
|
|
|
|
|
static_kv: bool,
|
|
|
|
|
) -> Optional[Tensor]:
|
|
|
|
|
# saved key padding masks have shape (bsz, seq_len)
|
|
|
|
|
if prev_key_padding_mask is not None and static_kv:
|
|
|
|
|
if prev_key_padding_mask is not None:
|
|
|
|
|
if static_kv:
|
|
|
|
|
new_key_padding_mask = prev_key_padding_mask
|
|
|
|
|
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
|
|
|
|
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
|
|
|
|
|
# During incremental decoding, as the padding token enters and
|
|
|
|
|
# leaves the frame, there will be a time when prev or current is None
|
|
|
|
|
elif prev_key_padding_mask is not None:
|
|
|
|
|
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
|
|
|
|
|
if prev_key_padding_mask.is_cuda:
|
|
|
|
|
filler = filler.to(prev_key_padding_mask.device)
|
|
|
|
|
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
|
|
|
|
|
else:
|
|
|
|
|
new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
|
|
|
|
|
|
|
|
|
|
elif key_padding_mask is not None:
|
|
|
|
|
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
|
|
|
|
|
if key_padding_mask.is_cuda:
|
|
|
|
|
filler = filler.cuda()
|
|
|
|
|
new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
|
|
|
|
|
filler = torch.zeros(
|
|
|
|
|
batch_size,
|
|
|
|
|
src_len - key_padding_mask.size(1),
|
|
|
|
|
dtype=key_padding_mask.dtype,
|
|
|
|
|
device=key_padding_mask.device,
|
|
|
|
|
)
|
|
|
|
|
new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
|
|
|
|
|
else:
|
|
|
|
|
new_key_padding_mask = prev_key_padding_mask
|
|
|
|
|
return new_key_padding_mask
|
|
|
|
|
|