From 291974c65cb9d93587826f80ed7d401f2bab2de7 Mon Sep 17 00:00:00 2001 From: Aleksey Tikhonov Date: Fri, 18 Dec 2020 22:32:10 +0100 Subject: [PATCH] GPT-model attention heads pruning example (#9189) * Pruning for GPT attn heads * The code formatted according to the transformers requirements * Update run_prune_gpt.py * Update run_prune_gpt.py --- .../bertology/run_prune_gpt.py | 388 ++++++++++++++++++ 1 file changed, 388 insertions(+) create mode 100644 examples/research_projects/bertology/run_prune_gpt.py diff --git a/examples/research_projects/bertology/run_prune_gpt.py b/examples/research_projects/bertology/run_prune_gpt.py new file mode 100644 index 0000000000..e89a6a4fdc --- /dev/null +++ b/examples/research_projects/bertology/run_prune_gpt.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +""" This script is adapted from the Bertology pruning code (https://github.com/huggingface/transformers/blob/783d7d2629e97c5f0c5f9ef01b8c66410275c204/examples/research_projects/bertology/run_bertology.py) +to prune GPT-like models. The author is @altsoph. +""" + +import argparse +import logging +import os +from datetime import datetime + +import numpy as np +import torch +from torch.utils.data import DataLoader, RandomSampler, TensorDataset +from tqdm import tqdm + +from transformers import GPT2LMHeadModel + + +logger = logging.getLogger(__name__) + + +def save_model(model, dirpath): + # save results + if os.path.exists(dirpath): + if os.path.exists(os.path.join(dirpath, "config.json")) and os.path.isfile( + os.path.join(dirpath, "config.json") + ): + os.remove(os.path.join(dirpath, "config.json")) + if os.path.exists(os.path.join(dirpath, "pytorch_model.bin")) and os.path.isfile( + os.path.join(dirpath, "pytorch_model.bin") + ): + os.remove(os.path.join(dirpath, "pytorch_model.bin")) + else: + os.makedirs(dirpath) + model.save_pretrained(dirpath) + + +def entropy(p, unlogit=False): + """ Compute the entropy of a probability distribution """ + exponent = 2 + if unlogit: + p = torch.pow(p, exponent) + plogp = p * torch.log(p) + plogp[p == 0] = 0 + return -plogp.sum(dim=-1) + + +def print_2d_tensor(tensor): + """ Print a 2D tensor """ + logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor)))) + for row in range(len(tensor)): + if tensor.dtype != torch.long: + logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data)) + else: + logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data)) + + +def compute_heads_importance( + args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False +): + """This method shows how to compute: + - head attention entropy + - head importance scores according to http://arxiv.org/abs/1905.10650 + """ + # Prepare our tensors + n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads + head_importance = torch.zeros(n_layers, n_heads).to(args.device) + attn_entropy = torch.zeros(n_layers, n_heads).to(args.device) + + if head_mask is None: + head_mask = torch.ones(n_layers, n_heads).to(args.device) + + head_mask.requires_grad_(requires_grad=True) + # If actually pruned attention multi-head, set head mask to None to avoid shape mismatch + if actually_pruned: + head_mask = None + + tot_tokens = 0.0 + total_loss = 0.0 + for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])): + inputs = tuple(t.to(args.device) for t in inputs) + (input_ids,) = inputs + + # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below) + outputs = model(input_ids, labels=input_ids, head_mask=head_mask) + # (loss), lm_logits, presents, (all hidden_states), (attentions) + loss, _, all_attentions = ( + outputs[0], + outputs[1], + outputs[-1], + ) # Loss and logits are the first, attention the last + loss.backward() # Backpropagate to populate the gradients in the head mask + total_loss += loss.detach().cpu().numpy() + if compute_entropy: + for layer, attn in enumerate(all_attentions): + masked_entropy = entropy(attn.detach(), True) + attn_entropy[layer] += masked_entropy.sum(-1).sum(0).sum(0).detach() + + if compute_importance: + head_importance += head_mask.grad.abs().detach() + tot_tokens += torch.ones_like(input_ids).float().detach().sum().data + + # Normalize + attn_entropy /= tot_tokens + head_importance /= tot_tokens + # Layerwise importance normalization + if not args.dont_normalize_importance_by_layer: + exponent = 2 + norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent) + head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20 + + if not args.dont_normalize_global_importance: + head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min()) + + # Print matrices + if compute_entropy: + logger.info("Attention entropies") + print_2d_tensor(attn_entropy) + if compute_importance: + logger.info("Head importance scores") + print_2d_tensor(head_importance) + logger.info("Head ranked by importance scores") + head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device) + head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange( + head_importance.numel(), device=args.device + ) + head_ranks = head_ranks.view_as(head_importance) + print_2d_tensor(head_ranks) + return attn_entropy, head_importance, total_loss + + +def mask_heads(args, model, eval_dataloader): + """This method shows how to mask head (set some heads to zero), to test the effect on the network, + based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650) + """ + _, head_importance, loss = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False) + original_score = 1 / loss # instead of downsteam score use the LM loss + logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold) + + new_head_mask = torch.ones_like(head_importance) + num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount)) + + current_score = original_score + while current_score >= original_score * args.masking_threshold: + head_mask = new_head_mask.clone().detach() # save current head mask + # heads from least important to most - keep only not-masked heads + head_importance[head_mask == 0.0] = float("Inf") + current_heads_to_mask = head_importance.view(-1).sort()[1] + + if len(current_heads_to_mask) <= num_to_mask: + print("BREAK BY num_to_mask") + break + + # mask heads + current_heads_to_mask = current_heads_to_mask[:num_to_mask] + logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist())) + new_head_mask = new_head_mask.view(-1) + new_head_mask[current_heads_to_mask] = 0.0 + new_head_mask = new_head_mask.view_as(head_mask) + new_head_mask = new_head_mask.clone().detach() + print_2d_tensor(new_head_mask) + + # Compute metric and head importance again + _, head_importance, loss = compute_heads_importance( + args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask + ) + current_score = 1 / loss + logger.info( + "Masking: current score: %f, remaining heads %d (%.1f percents)", + current_score, + new_head_mask.sum(), + new_head_mask.sum() / new_head_mask.numel() * 100, + ) + + logger.info("Final head mask") + print_2d_tensor(head_mask) + np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy()) + + return head_mask + + +def prune_heads(args, model, eval_dataloader, head_mask): + """This method shows how to prune head (remove heads weights) based on + the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650) + """ + # Try pruning and test time speedup + # Pruning is like masking but we actually remove the masked weights + before_time = datetime.now() + _, _, loss = compute_heads_importance( + args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask + ) + score_masking = 1 / loss + original_time = datetime.now() - before_time + + original_num_params = sum(p.numel() for p in model.parameters()) + heads_to_prune = dict( + (layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask)) + ) + + for k, v in heads_to_prune.items(): + if isinstance(v, int): + heads_to_prune[k] = [ + v, + ] + + assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item() + model.prune_heads(heads_to_prune) + pruned_num_params = sum(p.numel() for p in model.parameters()) + + before_time = datetime.now() + _, _, loss = compute_heads_importance( + args, + model, + eval_dataloader, + compute_entropy=False, + compute_importance=False, + head_mask=None, + actually_pruned=True, + ) + + score_pruning = 1 / loss + new_time = datetime.now() - before_time + + logger.info( + "Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)", + original_num_params, + pruned_num_params, + pruned_num_params / original_num_params * 100, + ) + logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning) + logger.info("Pruning: speed ratio (original timing / new timing): %f percents", original_time / new_time * 100) + save_model(model, args.output_dir) + + +def main(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--data_dir", + default=None, + type=str, + required=True, + help="The input data dir. Should contain the .tsv files (or other data files) for the task.", + ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models", + ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + + # Other parameters + parser.add_argument( + "--config_name", + default="", + type=str, + help="Pretrained config name or path if not the same as model_name_or_path", + ) + parser.add_argument( + "--tokenizer_name", + default="", + type=str, + help="Pretrained tokenizer name or path if not the same as model_name_or_path", + ) + parser.add_argument( + "--cache_dir", + default=None, + type=str, + help="Where do you want to store the pre-trained models downloaded from s3", + ) + parser.add_argument( + "--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances." + ) + parser.add_argument( + "--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory" + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + + parser.add_argument( + "--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers" + ) + parser.add_argument( + "--dont_normalize_global_importance", + action="store_true", + help="Don't normalize all importance scores between 0 and 1", + ) + + parser.add_argument( + "--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy." + ) + parser.add_argument( + "--masking_threshold", + default=0.9, + type=float, + help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).", + ) + parser.add_argument( + "--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step." + ) + parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.") + + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after WordPiece tokenization. \n" + "Sequences longer than this will be truncated, sequences shorter padded.", + ) + parser.add_argument("--batch_size", default=1, type=int, help="Batch size.") + + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") + parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available") + parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.") + parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.") + args = parser.parse_args() + + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() + + # Setup devices and distributed training + if args.local_rank == -1 or args.no_cuda: + args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() + else: + torch.cuda.set_device(args.local_rank) + args.device = torch.device("cuda", args.local_rank) + args.n_gpu = 1 + torch.distributed.init_process_group(backend="nccl") # Initializes the distributed backend + + # Setup logging + logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1))) + + model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path) + + # Distributed and parallel training + model.to(args.device) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True + ) + elif args.n_gpu > 1: + model = torch.nn.DataParallel(model) + + # Print/save training arguments + os.makedirs(args.output_dir, exist_ok=True) + torch.save(args, os.path.join(args.output_dir, "run_args.bin")) + logger.info("Training/evaluation parameters %s", args) + + # Prepare dataset + numpy_data = np.concatenate( + [ + np.loadtxt(args.data_dir, dtype=np.int64), + ] + ) + train_tensor_dataset = (torch.from_numpy(numpy_data),) + train_data = TensorDataset(*train_tensor_dataset) + train_sampler = RandomSampler(train_data) + eval_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size) + + # Compute head entropy and importance score + compute_heads_importance(args, model, eval_dataloader) + + # Try head masking (set heads to zero until the score goes under a threshole) + # and head pruning (remove masked heads and see the effect on the network) + if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0: + head_mask = mask_heads(args, model, eval_dataloader) + prune_heads(args, model, eval_dataloader, head_mask) + + +if __name__ == "__main__": + main()