From e4b46d86ce0cbcbc9011375add7f3713eb5ef967 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 19 Jun 2019 22:16:30 +0200 Subject: [PATCH] update head pruning --- examples/bertology.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/examples/bertology.py b/examples/bertology.py index 7efdfa459a..8c3fcfc443 100644 --- a/examples/bertology.py +++ b/examples/bertology.py @@ -92,7 +92,13 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, # Normalize attn_entropy /= tot_tokens head_importance /= tot_tokens - if args.normalize_importance: + # Layerwise importance normalization + if not args.dont_normalize_importance_by_layer: + exponent = 2 + norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1/exponent) + head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20 + + if not args.dont_normalize_global_importance: head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min()) return attn_entropy, head_importance, preds, labels @@ -106,7 +112,8 @@ def run_model(): parser.add_argument("--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances.") parser.add_argument("--overwrite_output_dir", action='store_true', help="Whether to overwrite data in output directory") - parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1") + parser.add_argument("--dont_normalize_importance_by_layer", action='store_true', help="Don't normalize importance score by layers") + parser.add_argument("--dont_normalize_global_importance", action='store_true', help="Don't normalize all importance scores between 0 and 1") parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.") parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics" @@ -243,21 +250,20 @@ def run_model(): current_score = original_score while current_score >= original_score * args.masking_threshold: - head_mask = new_head_mask - # heads from most important to least - heads_to_mask = head_importance.view(-1).sort(descending=True)[1] - # keep only not-masked heads - heads_to_mask = heads_to_mask[head_mask.view(-1).nonzero()][:, 0] + head_mask = new_head_mask # save current head mask + # heads from most important to least - keep only not-masked heads + head_importance = head_importance.view(-1)[head_mask.view(-1).nonzero()][:, 0] + current_heads_to_mask = head_importance.sort()[1] - if len(heads_to_mask) <= num_to_mask: + if len(current_heads_to_mask) <= num_to_mask: break # mask heads - heads_to_mask = heads_to_mask[-num_to_mask:] - logger.info("Heads to mask: %s", str(heads_to_mask.tolist())) + current_heads_to_mask = current_heads_to_mask[:num_to_mask] + logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist())) new_head_mask = head_mask.view(-1) - new_head_mask[heads_to_mask] = 0.0 - new_head_mask = new_head_mask.view_as(head_importance) + new_head_mask[current_heads_to_mask] = 0.0 + new_head_mask = new_head_mask.view_as(head_mask) print_2d_tensor(new_head_mask) # Compute metric and head importance again