more logging
This commit is contained in:
@@ -227,7 +227,7 @@ def run_model():
|
|||||||
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
|
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
|
||||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
original_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
original_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||||
logger.info("Pruning: original score: %f", original_score)
|
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
|
||||||
|
|
||||||
new_head_mask = torch.ones_like(head_importance)
|
new_head_mask = torch.ones_like(head_importance)
|
||||||
num_to_mask = int(new_head_mask.numel() * args.masking_amount)
|
num_to_mask = int(new_head_mask.numel() * args.masking_amount)
|
||||||
@@ -245,6 +245,7 @@ def run_model():
|
|||||||
|
|
||||||
# mask heads
|
# mask heads
|
||||||
heads_to_mask = heads_to_mask[-num_to_mask:]
|
heads_to_mask = heads_to_mask[-num_to_mask:]
|
||||||
|
logger.info("Heads to mask: %s", str(heads_to_mask.tolist()))
|
||||||
new_head_mask = head_mask.view(-1)
|
new_head_mask = head_mask.view(-1)
|
||||||
new_head_mask[heads_to_mask] = 0.0
|
new_head_mask[heads_to_mask] = 0.0
|
||||||
new_head_mask = new_head_mask.view_as(head_importance)
|
new_head_mask = new_head_mask.view_as(head_importance)
|
||||||
@@ -254,7 +255,7 @@ def run_model():
|
|||||||
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
|
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
|
||||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||||
logger.info("Masking: current score: %f, remaning heads %.1f percents", current_score, head_mask.sum()/head_mask.numel() * 100)
|
logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)
|
||||||
|
|
||||||
# Try pruning and test time speedup
|
# Try pruning and test time speedup
|
||||||
# Pruning is like masking but we actually remove the masked weights
|
# Pruning is like masking but we actually remove the masked weights
|
||||||
|
|||||||
Reference in New Issue
Block a user