[BART] Remove unused kwargs (#3279)

* Remove unused kwargs
* Don't call forward in tests
This commit is contained in:
Sam Shleifer
2020-03-15 23:00:44 -04:00
committed by GitHub
parent 3814e167d9
commit 5ea8ba67b4
3 changed files with 14 additions and 29 deletions

View File

@@ -223,9 +223,7 @@ class EncoderLayer(nn.Module):
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
x, attn_weights = self.self_attn(
query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions,
)
x, attn_weights = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
@@ -378,7 +376,7 @@ class DecoderLayer(nn.Module):
layer_state = {}
# next line mutates layer state
x, self_attn_weights = self.self_attn(
query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
@@ -393,7 +391,6 @@ class DecoderLayer(nn.Module):
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
static_kv=True,
need_weights=False, # not returning it so why compute it
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
@@ -548,16 +545,12 @@ class SelfAttention(nn.Module):
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
encoder_decoder_attention=False, # otherwise self_attention
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.num_heads = num_heads
self.dropout = dropout
@@ -566,13 +559,8 @@ class SelfAttention(nn.Module):
self.scaling = self.head_dim ** -0.5
self.encoder_decoder_attention = encoder_decoder_attention
qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim # True for all BART
assert self.encoder_decoder_attention or qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
@@ -587,7 +575,6 @@ class SelfAttention(nn.Module):
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = False,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
@@ -598,8 +585,6 @@ class SelfAttention(nn.Module):
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).