From 0e1e8128bf98d1154712d5725b8c44090821d26c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 19 Jun 2019 15:35:49 +0200 Subject: [PATCH] more logging --- examples/bertology.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/bertology.py b/examples/bertology.py index 694328162a..af7e661a52 100644 --- a/examples/bertology.py +++ b/examples/bertology.py @@ -227,7 +227,7 @@ def run_model(): _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False) preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) original_score = compute_metrics(task_name, preds, labels)[args.metric_name] - logger.info("Pruning: original score: %f", original_score) + logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold) new_head_mask = torch.ones_like(head_importance) num_to_mask = int(new_head_mask.numel() * args.masking_amount) @@ -245,6 +245,7 @@ def run_model(): # mask heads heads_to_mask = heads_to_mask[-num_to_mask:] + logger.info("Heads to mask: %s", str(heads_to_mask.tolist())) new_head_mask = head_mask.view(-1) new_head_mask[heads_to_mask] = 0.0 new_head_mask = new_head_mask.view_as(head_importance) @@ -254,7 +255,7 @@ def run_model(): _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask) preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) current_score = compute_metrics(task_name, preds, labels)[args.metric_name] - logger.info("Masking: current score: %f, remaning heads %.1f percents", current_score, head_mask.sum()/head_mask.numel() * 100) + logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100) # Try pruning and test time speedup # Pruning is like masking but we actually remove the masked weights