From 0c8e823b031d99d06bddff2b88fd4da2d7500117 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 29 Aug 2019 17:20:11 -0400 Subject: [PATCH] Added patch to remaining models --- pytorch_transformers/modeling_bert.py | 3 +++ pytorch_transformers/modeling_openai.py | 3 +++ pytorch_transformers/modeling_xlm.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index 5a65e442d0..9aa25edbe3 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -337,12 +337,14 @@ class BertAttention(nn.Module): super(BertAttention, self).__init__() self.self = BertSelfAttention(config) self.output = BertSelfOutput(config) + self.pruned_heads = [] def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) for head in heads: + head -= len(list(filter(lambda h: h < head, self.pruned_heads))) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -354,6 +356,7 @@ class BertAttention(nn.Module): # Update hyper params self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads.extend(heads) def forward(self, input_tensor, attention_mask, head_mask=None): self_outputs = self.self(input_tensor, attention_mask, head_mask) diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index ce3768c676..78e57b0c59 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -249,12 +249,14 @@ class Attention(nn.Module): self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.pruned_heads = [] def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.n_head, self.split_size // self.n_head) for head in heads: + head -= len(list(filter(lambda h: h < head, self.pruned_heads))) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -265,6 +267,7 @@ class Attention(nn.Module): # Update hyper params self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) self.n_head = self.n_head - len(heads) + self.pruned_heads.extend(heads) def _attn(self, q, k, v, head_mask=None): w = torch.matmul(q, k) diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 1e0f8d7c77..17e39528f8 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -271,6 +271,7 @@ class MultiHeadAttention(nn.Module): self.k_lin = nn.Linear(dim, dim) self.v_lin = nn.Linear(dim, dim) self.out_lin = nn.Linear(dim, dim) + self.pruned_heads = [] def prune_heads(self, heads): attention_head_size = self.dim // self.n_heads @@ -278,6 +279,7 @@ class MultiHeadAttention(nn.Module): return mask = torch.ones(self.n_heads, attention_head_size) for head in heads: + head -= len(list(filter(lambda h: h < head, self.pruned_heads))) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -289,6 +291,7 @@ class MultiHeadAttention(nn.Module): # Update hyper params self.n_heads = self.n_heads - len(heads) self.dim = attention_head_size * self.n_heads + self.pruned_heads.extend(heads) def forward(self, input, mask, kv=None, cache=None, head_mask=None): """