Fixed typo
This commit is contained in:
committed by
Julien Chaumond
parent
5d3b8daad2
commit
8e5d84fcc1
@@ -278,7 +278,7 @@ class BertAttention(nn.Module):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
||||
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
|
||||
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
|
||||
for head in heads:
|
||||
# Compute how many pruned heads are before the head and move the index accordingly
|
||||
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
||||
|
||||
Reference in New Issue
Block a user