Remove BertDecoderAttention and do the branching in the self-attention layer
This commit is contained in:
@@ -282,53 +282,13 @@ class BertAttention(nn.Module):
|
|||||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
self.pruned_heads = self.pruned_heads.union(heads)
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
||||||
self_outputs = self.self(hidden_states, attention_mask, head_mask)
|
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class BertDecoderAttention(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super(BertAttention, self).__init__()
|
|
||||||
self.self = BertGeneralAttention(config)
|
|
||||||
self.output = BertSelfOutput(config)
|
|
||||||
self.pruned_heads = set()
|
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
|
||||||
if len(heads) == 0:
|
|
||||||
return
|
|
||||||
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
|
||||||
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
|
|
||||||
for head in heads:
|
|
||||||
# Compute how many pruned heads are before the head and move the index accordingly
|
|
||||||
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
|
||||||
mask[head] = 0
|
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
|
||||||
index = torch.arange(len(mask))[mask].long()
|
|
||||||
|
|
||||||
# Prune linear layers
|
|
||||||
self.self.query = prune_linear_layer(self.self.query, index)
|
|
||||||
self.self.key = prune_linear_layer(self.self.key, index)
|
|
||||||
self.self.value = prune_linear_layer(self.self.value, index)
|
|
||||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
|
||||||
|
|
||||||
# Update hyper params and store pruned heads
|
|
||||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
|
||||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
|
||||||
self.pruned_heads = self.pruned_heads.union(heads)
|
|
||||||
|
|
||||||
def forward(self, query, key, value, attention_mask=None, head_mask=None):
|
|
||||||
self_outputs = self.self(query, key, value, attention_mask, head_mask)
|
|
||||||
# in encoder-decoder attention we use the output of the previous decoder stage as the query
|
|
||||||
# in the Multi-Head Attention. We thus pass query_tensor as the residual in BertOutput.
|
|
||||||
# This shows the limits of the current code architecture, which may benefit from some refactoring.
|
|
||||||
attention_output = self.output(self_outputs[0], query)
|
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class BertIntermediate(nn.Module):
|
class BertIntermediate(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertIntermediate, self).__init__()
|
super(BertIntermediate, self).__init__()
|
||||||
|
|||||||
Reference in New Issue
Block a user