[cleanup] consolidate some prune_heads logic (#4799)

This commit is contained in:
Sam Shleifer
2020-06-08 17:08:04 -04:00
committed by GitHub
parent 4c7f564f9a
commit a139d1a160
8 changed files with 54 additions and 59 deletions

View File

@@ -55,6 +55,20 @@ except ImportError:
return input
def find_pruneable_heads_and_indices(
    heads: List, n_heads: int, head_size: int, already_pruned_heads: set
) -> Tuple[set, "torch.LongTensor"]:
    """
    Find the heads that still need pruning and the flat indices of the entries to keep.

    Args:
        heads: Indices of the heads requested for pruning. Entries that also
            appear in ``already_pruned_heads`` are dropped.
        n_heads: Number of rows in the mask that is built.
            NOTE(review): because requested indices are shifted down by the
            number of already-pruned heads before them, this is presumably the
            number of heads *currently remaining* -- confirm against callers.
        head_size: Number of entries per head (columns of the mask).
        already_pruned_heads: Heads pruned by earlier calls. Any iterable of
            ints is accepted; it is converted to a set internally.

    Returns:
        A tuple ``(heads, index)`` where ``heads`` is the set of requested heads
        minus the already-pruned ones, and ``index`` is a ``LongTensor`` holding
        the positions, within the flattened ``(n_heads, head_size)`` mask, of
        the entries that are kept (i.e. that do not belong to a pruned head).
    """
    # Coerce once so that a list/tuple does not break the set difference below.
    already_pruned = set(already_pruned_heads)

    mask = torch.ones(n_heads, head_size)
    heads = set(heads) - already_pruned  # Convert to set and remove already pruned heads
    for head in heads:
        # Every previously pruned head with a smaller index has shifted this
        # head's row down by one, so move the index accordingly.
        shifted_head = head - sum(1 for pruned in already_pruned if pruned < head)
        mask[shifted_head] = 0

    # Flatten to one boolean per entry: True means the entry is kept.
    mask = mask.view(-1).contiguous().eq(1)
    index: "torch.LongTensor" = torch.arange(len(mask))[mask].long()
    return heads, index
class ModuleUtilsMixin:
"""
A few utilities for torch.nn.Modules, to be used as a mixin.