Fixed typo
This commit is contained in:
committed by
Julien Chaumond
parent
5d3b8daad2
commit
8e5d84fcc1
@@ -278,7 +278,7 @@ class BertAttention(nn.Module):
|
|||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
||||||
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
|
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
|
||||||
for head in heads:
|
for head in heads:
|
||||||
# Compute how many pruned heads are before the head and move the index accordingly
|
# Compute how many pruned heads are before the head and move the index accordingly
|
||||||
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
||||||
|
|||||||
Reference in New Issue
Block a user