updated pruning logic with sets - Bert and GPT-2
This commit is contained in:
@@ -233,25 +233,29 @@ class Attention(nn.Module):
|
||||
self.c_proj = Conv1D(n_state, nx)
|
||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
self.pruned_heads = []
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
||||
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
|
||||
for head in heads:
|
||||
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
|
||||
# Compute how many pruned heads are before the head and move the index accordingly
|
||||
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
||||
mask[head] = 0
|
||||
mask = mask.view(-1).contiguous().eq(1)
|
||||
index = torch.arange(len(mask))[mask].long()
|
||||
index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
|
||||
|
||||
# Prune conv1d layers
|
||||
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
||||
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
||||
|
||||
# Update hyper params
|
||||
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
|
||||
self.n_head = self.n_head - len(heads)
|
||||
self.pruned_heads.extend(heads)
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def _attn(self, q, k, v, head_mask=None):
|
||||
w = torch.matmul(q, k)
|
||||
@@ -357,7 +361,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||
|
||||
def init_weights(self, module):
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights.
|
||||
"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
||||
@@ -456,14 +460,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
|
||||
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
if hasattr(config, "pruned_heads"):
|
||||
pruned_heads = config.pruned_heads.copy().items()
|
||||
config.pruned_heads = {}
|
||||
for layer, heads in pruned_heads:
|
||||
if self.h[int(layer)].attn.n_head == config.n_head:
|
||||
self.prune_heads({int(layer): list(map(int, heads))})
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
|
||||
@@ -594,7 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
self.transformer = GPT2Model(config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
@@ -718,7 +715,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.multiple_choice_head = SequenceSummary(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
|
||||
Reference in New Issue
Block a user