[Bart/Memory] SelfAttention only returns weights if config.output_attentions (#3369)
This commit is contained in:
@@ -217,7 +217,9 @@ class EncoderLayer(nn.Module):
|
|||||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||||
"""
|
"""
|
||||||
residual = x
|
residual = x
|
||||||
x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
|
x, attn_weights = self.self_attn(
|
||||||
|
query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
|
||||||
|
)
|
||||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
x = residual + x
|
x = residual + x
|
||||||
x = self.self_attn_layer_norm(x)
|
x = self.self_attn_layer_norm(x)
|
||||||
@@ -316,6 +318,7 @@ class DecoderLayer(nn.Module):
|
|||||||
def __init__(self, config: BartConfig):
|
def __init__(self, config: BartConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.d_model
|
self.embed_dim = config.d_model
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
self.self_attn = SelfAttention(
|
self.self_attn = SelfAttention(
|
||||||
embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
|
embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
|
||||||
)
|
)
|
||||||
@@ -343,14 +346,16 @@ class DecoderLayer(nn.Module):
|
|||||||
if layer_state is None:
|
if layer_state is None:
|
||||||
layer_state = {}
|
layer_state = {}
|
||||||
# next line mutates layer state
|
# next line mutates layer state
|
||||||
x, self_attn_weights = self.self_attn(query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,)
|
x, self_attn_weights = self.self_attn(
|
||||||
|
query=x, key=x, layer_state=layer_state, attn_mask=attention_mask, need_weights=self.output_attentions
|
||||||
|
)
|
||||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
x = residual + x
|
x = residual + x
|
||||||
x = self.self_attn_layer_norm(x)
|
x = self.self_attn_layer_norm(x)
|
||||||
residual = x
|
residual = x
|
||||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||||
|
|
||||||
x, encoder_attn_weights = self.encoder_attn(
|
x, _ = self.encoder_attn(
|
||||||
query=x,
|
query=x,
|
||||||
key=encoder_hidden_states,
|
key=encoder_hidden_states,
|
||||||
key_padding_mask=encoder_attn_mask,
|
key_padding_mask=encoder_attn_mask,
|
||||||
@@ -527,6 +532,7 @@ class SelfAttention(nn.Module):
|
|||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
need_weights=False,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
||||||
static_kv = self.encoder_decoder_attention # type: bool
|
static_kv = self.encoder_decoder_attention # type: bool
|
||||||
@@ -598,7 +604,10 @@ class SelfAttention(nn.Module):
|
|||||||
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
|
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
|
||||||
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
if need_weights:
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
else:
|
||||||
|
attn_weights = None
|
||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
||||||
|
|||||||
Reference in New Issue
Block a user