more logging
This commit is contained in:
@@ -227,7 +227,7 @@ def run_model():
|
||||
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
original_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||
logger.info("Pruning: original score: %f", original_score)
|
||||
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
|
||||
|
||||
new_head_mask = torch.ones_like(head_importance)
|
||||
num_to_mask = int(new_head_mask.numel() * args.masking_amount)
|
||||
@@ -245,6 +245,7 @@ def run_model():
|
||||
|
||||
# mask heads
|
||||
heads_to_mask = heads_to_mask[-num_to_mask:]
|
||||
logger.info("Heads to mask: %s", str(heads_to_mask.tolist()))
|
||||
new_head_mask = head_mask.view(-1)
|
||||
new_head_mask[heads_to_mask] = 0.0
|
||||
new_head_mask = new_head_mask.view_as(head_importance)
|
||||
@@ -254,7 +255,7 @@ def run_model():
|
||||
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||
logger.info("Masking: current score: %f, remaning heads %.1f percents", current_score, head_mask.sum()/head_mask.numel() * 100)
|
||||
logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)
|
||||
|
||||
# Try pruning and test time speedup
|
||||
# Pruning is like masking but we actually remove the masked weights
|
||||
|
||||
Reference in New Issue
Block a user