Fixed typo

This commit is contained in:
v_sboliu
2019-11-26 12:04:31 +08:00
committed by Julien Chaumond
parent 5d3b8daad2
commit 8e5d84fcc1

View File

@@ -278,7 +278,7 @@ class BertAttention(nn.Module):
if len(heads) == 0:
return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)