Re-styling in seq2seq attention (#11613)
This commit is contained in:
@@ -210,28 +210,26 @@ class BartAttention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -249,18 +247,15 @@ class BartAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
@@ -211,28 +211,26 @@ class BlenderbotAttention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -250,18 +248,15 @@ class BlenderbotAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
@@ -209,28 +209,26 @@ class BlenderbotSmallAttention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -248,18 +246,15 @@ class BlenderbotSmallAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
@@ -280,28 +280,26 @@ class M2M100Attention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -319,18 +317,15 @@ class M2M100Attention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
@@ -226,28 +226,26 @@ class MarianAttention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -265,18 +263,15 @@ class MarianAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
@@ -217,28 +217,26 @@ class MBartAttention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -256,18 +254,15 @@ class MBartAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
@@ -226,28 +226,26 @@ class PegasusAttention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -265,18 +263,15 @@ class PegasusAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
@@ -293,28 +293,26 @@ class Speech2TextAttention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -332,18 +330,15 @@ class Speech2TextAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
@@ -356,28 +356,26 @@ class Wav2Vec2Attention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -395,18 +393,15 @@ class Wav2Vec2Attention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
@@ -1721,28 +1721,26 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
|||||||
src_len = key_states.size(1)
|
src_len = key_states.size(1)
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
assert attn_weights.size() == (
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
src_len,
|
)
|
||||||
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert attention_mask.size() == (
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
bsz,
|
raise ValueError(
|
||||||
1,
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
tgt_len,
|
)
|
||||||
src_len,
|
|
||||||
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
assert layer_head_mask.size() == (
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
self.num_heads,
|
raise ValueError(
|
||||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -1760,18 +1758,15 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
assert attn_output.size() == (
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
bsz * self.num_heads,
|
raise ValueError(
|
||||||
tgt_len,
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
self.head_dim,
|
|
||||||
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
|
||||||
|
|
||||||
attn_output = (
|
|
||||||
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.reshape(bsz, tgt_len, embed_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|||||||
Reference in New Issue
Block a user