[cleanup] consolidate some prune_heads logic (#4799)
This commit is contained in:
@@ -55,6 +55,20 @@ except ImportError:
|
||||
return input
|
||||
|
||||
|
||||
def find_pruneable_heads_and_indices(
    heads: List, n_heads: int, head_size: int, already_pruned_heads: set
) -> Tuple[set, "torch.LongTensor"]:
    """
    Find the heads that still need pruning and the flat indices of everything that is kept.

    Args:
        heads: Head indices requested for pruning (any iterable of ints).
        n_heads: Number of rows in the mask, i.e. the number of heads currently present.
        head_size: Number of entries per head.
        already_pruned_heads: Head indices pruned earlier. These are removed from ``heads`` and
            used to shift the remaining indices. Any iterable is accepted; it is converted to a set.

    Returns:
        A tuple ``(heads, index)`` where ``heads`` is the set of requested heads that were not
        already pruned (in their original numbering), and ``index`` is a 1-D ``torch.LongTensor``
        holding the positions, within the flattened ``n_heads * head_size`` layout, of the entries
        that belong to heads which are kept.
    """
    # Accept lists/tuples as well as sets; set difference below requires a set operand.
    already_pruned_heads = set(already_pruned_heads)

    mask = torch.ones(n_heads, head_size)
    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
    for head in heads:
        # Heads pruned earlier no longer occupy a row, so shift this head's row down by the
        # number of already pruned heads that came before it.
        row = head - sum(1 for h in already_pruned_heads if h < head)
        mask[row] = 0

    # Flatten to one entry per (head, position) and keep only the positions still set to 1.
    mask = mask.view(-1).contiguous().eq(1)
    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
    return heads, index
|
||||
|
||||
|
||||
class ModuleUtilsMixin:
|
||||
"""
|
||||
A few utilities for torch.nn.Modules, to be used as a mixin.
|
||||
|
||||
Reference in New Issue
Block a user