update bertology
This commit is contained in:
@@ -281,9 +281,11 @@ def run_model():
|
|||||||
score_masking = compute_metrics(task_name, preds, labels)[args.metric_name]
|
score_masking = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||||
original_time = datetime.now() - before_time
|
original_time = datetime.now() - before_time
|
||||||
|
|
||||||
|
original_num_params = sum(p.numel() for p in model.parameters())
|
||||||
heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask)))
|
heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask)))
|
||||||
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
||||||
model.bert.prune_heads(heads_to_prune)
|
model.bert.prune_heads(heads_to_prune)
|
||||||
|
pruned_num_params = sum(p.numel() for p in model.parameters())
|
||||||
|
|
||||||
before_time = datetime.now()
|
before_time = datetime.now()
|
||||||
_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
|
_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
|
||||||
@@ -292,6 +294,7 @@ def run_model():
|
|||||||
score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name]
|
score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||||
new_time = datetime.now() - before_time
|
new_time = datetime.now() - before_time
|
||||||
|
|
||||||
|
logger.info("Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)", original_num_params, pruned_num_params, pruned_num_params/original_num_params * 100)
|
||||||
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
||||||
logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time/new_time * 100)
|
logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time/new_time * 100)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user