[BART] Remove unused kwargs (#3279)
* Remove unused kwargs * dont call forward in tests
This commit is contained in:
@@ -223,9 +223,7 @@ class EncoderLayer(nn.Module):
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
x, attn_weights = self.self_attn(
|
||||
query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions,
|
||||
)
|
||||
x, attn_weights = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
@@ -378,7 +376,7 @@ class DecoderLayer(nn.Module):
|
||||
layer_state = {}
|
||||
# next line mutates layer state
|
||||
x, self_attn_weights = self.self_attn(
|
||||
query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
|
||||
query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
@@ -393,7 +391,6 @@ class DecoderLayer(nn.Module):
|
||||
key_padding_mask=encoder_attn_mask,
|
||||
layer_state=layer_state, # mutates layer state
|
||||
static_kv=True,
|
||||
need_weights=False, # not returning it so why compute it
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
@@ -548,16 +545,12 @@ class SelfAttention(nn.Module):
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
encoder_decoder_attention=False, # otherwise self_attention
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
@@ -566,13 +559,8 @@ class SelfAttention(nn.Module):
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
self.encoder_decoder_attention = encoder_decoder_attention
|
||||
qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim # True for all BART
|
||||
|
||||
assert self.encoder_decoder_attention or qkv_same_dim, (
|
||||
"Self-attention requires query, key and " "value to be of the same size"
|
||||
)
|
||||
self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
|
||||
@@ -587,7 +575,6 @@ class SelfAttention(nn.Module):
|
||||
value: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
need_weights: bool = False,
|
||||
static_kv: bool = False,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
@@ -598,8 +585,6 @@ class SelfAttention(nn.Module):
|
||||
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||
keys that are pads, of shape `(batch, src_len)`, where
|
||||
padding elements are indicated by 1s.
|
||||
need_weights (bool, optional): return the attention weights,
|
||||
averaged over heads (default: False).
|
||||
attn_mask (ByteTensor, optional): typically used to
|
||||
implement causal attention, where the mask prevents the
|
||||
attention from looking forward in time (default: None).
|
||||
|
||||
Reference in New Issue
Block a user