From 271bedb4857c75fd6183fcf9f479913469f40fef Mon Sep 17 00:00:00 2001 From: Tobias Lee Date: Thu, 21 May 2020 21:17:03 +0800 Subject: [PATCH] [examples] fix no grad in second pruning in run_bertology (#4479) * fix no grad in second pruning and typo * fix prune heads attention mismatch problem * fix * fix * fix * run make style * run make style --- examples/bertology/run_bertology.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/bertology/run_bertology.py b/examples/bertology/run_bertology.py index 8d26bf890c..331da73954 100644 --- a/examples/bertology/run_bertology.py +++ b/examples/bertology/run_bertology.py @@ -64,7 +64,7 @@ def print_2d_tensor(tensor): def compute_heads_importance( - args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None + args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False ): """ This method shows how to compute: - head attention entropy @@ -77,7 +77,12 @@ def compute_heads_importance( if head_mask is None: head_mask = torch.ones(n_layers, n_heads).to(args.device) + head_mask.requires_grad_(requires_grad=True) + # If actually pruned attention multi-head, set head mask to None to avoid shape mismatch + if actually_pruned: + head_mask = None + preds = None labels = None tot_tokens = 0.0 @@ -172,6 +177,7 @@ def mask_heads(args, model, eval_dataloader): new_head_mask = new_head_mask.view(-1) new_head_mask[current_heads_to_mask] = 0.0 new_head_mask = new_head_mask.view_as(head_mask) + new_head_mask = new_head_mask.clone().detach() print_2d_tensor(new_head_mask) # Compute metric and head importance again @@ -181,7 +187,7 @@ def mask_heads(args, model, eval_dataloader): preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name] logger.info( - "Masking: current score: %f, remaning heads %d (%.1f percents)", + "Masking: current score: %f, remaining heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum() / new_head_mask.numel() * 100, @@ -209,14 +215,23 @@ def prune_heads(args, model, eval_dataloader, head_mask): original_time = datetime.now() - before_time original_num_params = sum(p.numel() for p in model.parameters()) - heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask))) + heads_to_prune = dict( + (layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask)) + ) + assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item() model.prune_heads(heads_to_prune) pruned_num_params = sum(p.numel() for p in model.parameters()) before_time = datetime.now() _, _, preds, labels = compute_heads_importance( - args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=None + args, + model, + eval_dataloader, + compute_entropy=False, + compute_importance=False, + head_mask=None, + actually_pruned=True, ) preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]