From 63f4d8cad010f1972254007ad56b22fe5ed203fe Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 26 Mar 2020 18:42:39 -0400 Subject: [PATCH] =?UTF-8?q?[Bart/Memory]=20SelfAttention=20only=20returns?= =?UTF-8?q?=20weights=20if=20config.outp=E2=80=A6=20(#3369)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/modeling_bart.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index ab44440197..e4885682be 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -217,7 +217,9 @@ class EncoderLayer(nn.Module): encoded output of shape `(seq_len, batch, embed_dim)` """ residual = x - x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,) + x, attn_weights = self.self_attn( + query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions + ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.self_attn_layer_norm(x) @@ -316,6 +318,7 @@ class DecoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model + self.output_attentions = config.output_attentions self.self_attn = SelfAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) @@ -343,14 +346,16 @@ class DecoderLayer(nn.Module): if layer_state is None: layer_state = {} # next line mutates layer state - x, self_attn_weights = self.self_attn(query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,) + x, self_attn_weights = self.self_attn( + query=x, key=x, layer_state=layer_state, attn_mask=attention_mask, need_weights=self.output_attentions + ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.self_attn_layer_norm(x) residual = x assert self.encoder_attn.cache_key != self.self_attn.cache_key - x, encoder_attn_weights = self.encoder_attn( + x, _ = self.encoder_attn( query=x, key=encoder_hidden_states, key_padding_mask=encoder_attn_mask, @@ -527,6 +532,7 @@ class SelfAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, layer_state: Optional[Dict[str, Optional[Tensor]]] = None, attn_mask: Optional[Tensor] = None, + need_weights=False, ) -> Tuple[Tensor, Optional[Tensor]]: """Input shape: Time(SeqLen) x Batch x Channel""" static_kv = self.encoder_decoder_attention # type: bool @@ -598,7 +604,10 @@ class SelfAttention(nn.Module): assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn_output = self.out_proj(attn_output) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if need_weights: + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + else: + attn_weights = None return attn_output, attn_weights def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):