[Bart/Memory] SelfAttention only returns weights if config.outp… (#3369)
This commit is contained in:
@@ -217,7 +217,9 @@ class EncoderLayer(nn.Module):
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
|
||||
x, attn_weights = self.self_attn(
|
||||
query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
@@ -316,6 +318,7 @@ class DecoderLayer(nn.Module):
|
||||
def __init__(self, config: BartConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
self.output_attentions = config.output_attentions
|
||||
self.self_attn = SelfAttention(
|
||||
embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
|
||||
)
|
||||
@@ -343,14 +346,16 @@ class DecoderLayer(nn.Module):
|
||||
if layer_state is None:
|
||||
layer_state = {}
|
||||
# next line mutates layer state
|
||||
x, self_attn_weights = self.self_attn(query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,)
|
||||
x, self_attn_weights = self.self_attn(
|
||||
query=x, key=x, layer_state=layer_state, attn_mask=attention_mask, need_weights=self.output_attentions
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
residual = x
|
||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||
|
||||
x, encoder_attn_weights = self.encoder_attn(
|
||||
x, _ = self.encoder_attn(
|
||||
query=x,
|
||||
key=encoder_hidden_states,
|
||||
key_padding_mask=encoder_attn_mask,
|
||||
@@ -527,6 +532,7 @@ class SelfAttention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
need_weights=False,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
||||
static_kv = self.encoder_decoder_attention # type: bool
|
||||
@@ -598,7 +604,10 @@ class SelfAttention(nn.Module):
|
||||
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
if need_weights:
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
||||
|
||||
Reference in New Issue
Block a user