first version bertology ok
This commit is contained in:
@@ -25,17 +25,20 @@ def entropy(p):
|
|||||||
plogp[p == 0] = 0
|
plogp[p == 0] = 0
|
||||||
return -plogp.sum(dim=-1)
|
return -plogp.sum(dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def print_1d_tensor(tensor, prefix=""):
|
def print_1d_tensor(tensor, prefix=""):
|
||||||
if tensor.dtype != torch.long:
|
if tensor.dtype != torch.long:
|
||||||
logger.info(prefix + "\t".join(f"{x:.5f}" for x in tensor.cpu().data))
|
logger.info(prefix + "\t".join(f"{x:.5f}" for x in tensor.cpu().data))
|
||||||
else:
|
else:
|
||||||
logger.info(prefix + "\t".join(f"{x:d}" for x in tensor.cpu().data))
|
logger.info(prefix + "\t".join(f"{x:d}" for x in tensor.cpu().data))
|
||||||
|
|
||||||
|
|
||||||
def print_2d_tensor(tensor):
|
def print_2d_tensor(tensor):
|
||||||
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
||||||
for row in range(len(tensor)):
|
for row in range(len(tensor)):
|
||||||
print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t")
|
print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t")
|
||||||
|
|
||||||
|
|
||||||
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
|
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
|
||||||
""" Example on how to use model outputs to compute:
|
""" Example on how to use model outputs to compute:
|
||||||
- head attention entropy (activated by setting output_attentions=True when we created the model
|
- head attention entropy (activated by setting output_attentions=True when we created the model
|
||||||
@@ -54,7 +57,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
input_ids, input_mask, segment_ids, label_ids = batch
|
input_ids, input_mask, segment_ids, label_ids = batch
|
||||||
|
|
||||||
# Do a forward pass (not in torch.no_grad() since we need gradients for importance score - see below)
|
# Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
|
||||||
all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask)
|
all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask)
|
||||||
|
|
||||||
if compute_entropy:
|
if compute_entropy:
|
||||||
@@ -103,6 +106,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
|
|||||||
|
|
||||||
return attn_entropy, head_importance, preds, labels
|
return attn_entropy, head_importance, preds, labels
|
||||||
|
|
||||||
|
|
||||||
def run_model():
|
def run_model():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model_name_or_path', type=str, default='bert-base-cased-finetuned-mrpc', help='pretrained model name or path to local checkpoint')
|
parser.add_argument('--model_name_or_path', type=str, default='bert-base-cased-finetuned-mrpc', help='pretrained model name or path to local checkpoint')
|
||||||
@@ -212,7 +216,7 @@ def run_model():
|
|||||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||||
|
|
||||||
if args.data_subset > 0:
|
if args.data_subset > 0:
|
||||||
eval_data = Subset(eval_data, list(range(args.data_subset)))
|
eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data)))))
|
||||||
|
|
||||||
eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
|
eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
|
||||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
|
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
|
||||||
@@ -246,14 +250,14 @@ def run_model():
|
|||||||
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
|
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
|
||||||
|
|
||||||
new_head_mask = torch.ones_like(head_importance)
|
new_head_mask = torch.ones_like(head_importance)
|
||||||
num_to_mask = int(new_head_mask.numel() * args.masking_amount)
|
num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))
|
||||||
|
|
||||||
current_score = original_score
|
current_score = original_score
|
||||||
while current_score >= original_score * args.masking_threshold:
|
while current_score >= original_score * args.masking_threshold:
|
||||||
head_mask = new_head_mask # save current head mask
|
head_mask = new_head_mask.clone() # save current head mask
|
||||||
# heads from most important to least - keep only not-masked heads
|
# heads from least important to most - keep only not-masked heads
|
||||||
head_importance = head_importance.view(-1)[head_mask.view(-1).nonzero()][:, 0]
|
head_importance[head_mask == 0.0] = float('Inf')
|
||||||
current_heads_to_mask = head_importance.sort()[1]
|
current_heads_to_mask = head_importance.view(-1).sort()[1]
|
||||||
|
|
||||||
if len(current_heads_to_mask) <= num_to_mask:
|
if len(current_heads_to_mask) <= num_to_mask:
|
||||||
break
|
break
|
||||||
@@ -261,7 +265,7 @@ def run_model():
|
|||||||
# mask heads
|
# mask heads
|
||||||
current_heads_to_mask = current_heads_to_mask[:num_to_mask]
|
current_heads_to_mask = current_heads_to_mask[:num_to_mask]
|
||||||
logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
|
logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
|
||||||
new_head_mask = head_mask.view(-1)
|
new_head_mask = new_head_mask.view(-1)
|
||||||
new_head_mask[current_heads_to_mask] = 0.0
|
new_head_mask[current_heads_to_mask] = 0.0
|
||||||
new_head_mask = new_head_mask.view_as(head_mask)
|
new_head_mask = new_head_mask.view_as(head_mask)
|
||||||
print_2d_tensor(new_head_mask)
|
print_2d_tensor(new_head_mask)
|
||||||
@@ -272,6 +276,10 @@ def run_model():
|
|||||||
current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||||
logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)
|
logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)
|
||||||
|
|
||||||
|
logger.info("Final head mask")
|
||||||
|
print_2d_tensor(head_mask)
|
||||||
|
np.save(os.path.join(args.output_dir, 'head_mask.npy'), head_mask.detach().cpu().numpy())
|
||||||
|
|
||||||
# Try pruning and test time speedup
|
# Try pruning and test time speedup
|
||||||
# Pruning is like masking but we actually remove the masked weights
|
# Pruning is like masking but we actually remove the masked weights
|
||||||
before_time = datetime.now()
|
before_time = datetime.now()
|
||||||
|
|||||||
Reference in New Issue
Block a user