From 0cd283522ab46a9c1c50576be4fd309c08974d8e Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Tue, 27 Aug 2019 15:56:59 -0400 Subject: [PATCH] Attempt to fix head index --- pytorch_transformers/modeling_gpt2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 8aa5347c71..8b39ad372e 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -233,12 +233,14 @@ class Attention(nn.Module): self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.pruned_heads = [] def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.n_head, self.split_size // self.n_head) for head in heads: + head -= len(list(filter(lambda h: h < head, self.pruned_heads))) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -249,6 +251,7 @@ class Attention(nn.Module): # Update hyper params self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) self.n_head = self.n_head - len(heads) + self.pruned_heads.extend(heads) def _attn(self, q, k, v, head_mask=None): w = torch.matmul(q, k)