Apply ruff flake8-comprehensions (#21694)
This commit is contained in:
@@ -218,9 +218,9 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
original_time = datetime.now() - before_time
|
||||
|
||||
original_num_params = sum(p.numel() for p in model.parameters())
|
||||
heads_to_prune = dict(
|
||||
(layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
|
||||
)
|
||||
heads_to_prune = {
|
||||
layer: (1 - head_mask[layer].long()).nonzero().squeeze().tolist() for layer in range(len(head_mask))
|
||||
}
|
||||
|
||||
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
||||
model.prune_heads(heads_to_prune)
|
||||
|
||||
@@ -194,9 +194,9 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
original_time = datetime.now() - before_time
|
||||
|
||||
original_num_params = sum(p.numel() for p in model.parameters())
|
||||
heads_to_prune = dict(
|
||||
(layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
|
||||
)
|
||||
heads_to_prune = {
|
||||
layer: (1 - head_mask[layer].long()).nonzero().squeeze().tolist() for layer in range(len(head_mask))
|
||||
}
|
||||
|
||||
for k, v in heads_to_prune.items():
|
||||
if isinstance(v, int):
|
||||
|
||||
Reference in New Issue
Block a user