more logging

This commit is contained in:
thomwolf
2019-06-19 15:35:49 +02:00
parent 909d4f1af2
commit 0e1e8128bf

View File

@@ -227,7 +227,7 @@ def run_model():
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False) _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
original_score = compute_metrics(task_name, preds, labels)[args.metric_name] original_score = compute_metrics(task_name, preds, labels)[args.metric_name]
logger.info("Pruning: original score: %f", original_score) logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
new_head_mask = torch.ones_like(head_importance) new_head_mask = torch.ones_like(head_importance)
num_to_mask = int(new_head_mask.numel() * args.masking_amount) num_to_mask = int(new_head_mask.numel() * args.masking_amount)
@@ -245,6 +245,7 @@ def run_model():
# mask heads # mask heads
heads_to_mask = heads_to_mask[-num_to_mask:] heads_to_mask = heads_to_mask[-num_to_mask:]
logger.info("Heads to mask: %s", str(heads_to_mask.tolist()))
new_head_mask = head_mask.view(-1) new_head_mask = head_mask.view(-1)
new_head_mask[heads_to_mask] = 0.0 new_head_mask[heads_to_mask] = 0.0
new_head_mask = new_head_mask.view_as(head_importance) new_head_mask = new_head_mask.view_as(head_importance)
@@ -254,7 +255,7 @@ def run_model():
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask) _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
current_score = compute_metrics(task_name, preds, labels)[args.metric_name] current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
logger.info("Masking: current score: %f, remaning heads %.1f percents", current_score, head_mask.sum()/head_mask.numel() * 100) logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)
# Try pruning and test time speedup # Try pruning and test time speedup
# Pruning is like masking but we actually remove the masked weights # Pruning is like masking but we actually remove the masked weights