From edfc8f822557f3df7d9057a6457a933cddf15299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 10:17:27 +0200 Subject: [PATCH] Remove and do the branching in --- transformers/modeling_bert.py | 44 ++--------------------------------- 1 file changed, 2 insertions(+), 42 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 94791571cd..89407ff8ab 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -282,53 +282,13 @@ class BertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, attention_mask=None, head_mask=None): - self_outputs = self.self(hidden_states, attention_mask, head_mask) + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None): + self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs -class BertDecoderAttention(nn.Module): - def __init__(self, config): - super(BertAttention, self).__init__() - self.self = BertGeneralAttention(config) - self.output = BertSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) - heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads - for head in heads: - # Compute how many pruned heads are before the head and move the index accordingly - head = head - sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len(heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward(self, query, key, value, attention_mask=None, head_mask=None): - self_outputs = self.self(query, key, value, attention_mask, head_mask) - # in encoder-decoder attention we use the output of the previous decoder stage as the query - # in the Multi-Head Attention. We thus pass query_tensor as the residual in BertOutput. - # This shows the limits of the current code architecture, which may benefit from some refactoring. - attention_output = self.output(self_outputs[0], query) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - class BertIntermediate(nn.Module): def __init__(self, config): super(BertIntermediate, self).__init__()