From 7766ce66dd788a780a54f890711ad58cb3d96f5f Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 19 Jun 2019 22:29:51 +0200 Subject: [PATCH] update bertology --- examples/bertology.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/bertology.py b/examples/bertology.py index 8c3fcfc443..bf32e8e174 100644 --- a/examples/bertology.py +++ b/examples/bertology.py @@ -281,9 +281,11 @@ def run_model(): score_masking = compute_metrics(task_name, preds, labels)[args.metric_name] original_time = datetime.now() - before_time + original_num_params = sum(p.numel() for p in model.parameters()) heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask))) assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item() model.bert.prune_heads(heads_to_prune) + pruned_num_params = sum(p.numel() for p in model.parameters()) before_time = datetime.now() _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader, @@ -292,6 +294,7 @@ def run_model(): score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name] new_time = datetime.now() - before_time + logger.info("Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)", original_num_params, pruned_num_params, pruned_num_params/original_num_params * 100) logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning) logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time/new_time * 100)