Black 20 release
This commit is contained in:
@@ -66,9 +66,9 @@ def print_2d_tensor(tensor):
|
||||
def compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
|
||||
):
|
||||
""" This method shows how to compute:
|
||||
- head attention entropy
|
||||
- head importance scores according to http://arxiv.org/abs/1905.10650
|
||||
"""This method shows how to compute:
|
||||
- head attention entropy
|
||||
- head importance scores according to http://arxiv.org/abs/1905.10650
|
||||
"""
|
||||
# Prepare our tensors
|
||||
n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
||||
@@ -150,8 +150,8 @@ def compute_heads_importance(
|
||||
|
||||
|
||||
def mask_heads(args, model, eval_dataloader):
|
||||
""" This method shows how to mask head (set some heads to zero), to test the effect on the network,
|
||||
based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""This method shows how to mask head (set some heads to zero), to test the effect on the network,
|
||||
based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
@@ -201,8 +201,8 @@ def mask_heads(args, model, eval_dataloader):
|
||||
|
||||
|
||||
def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
""" This method shows how to prune head (remove heads weights) based on
|
||||
the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""This method shows how to prune head (remove heads weights) based on
|
||||
the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
# Try pruning and test time speedup
|
||||
# Pruning is like masking but we actually remove the masked weights
|
||||
@@ -395,7 +395,8 @@ def main():
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, cache_dir=args.cache_dir,
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
|
||||
Reference in New Issue
Block a user