Reorganize examples (#9010)
* Reorganize example folder * Continue reorganization * Change requirements for tests * Final cleanup * Finish regroup with tests all passing * Copyright * Requirements and readme * Make a full link for the documentation * Address review comments * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Add symlink * Reorg again * Apply suggestions from code review Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> * Adapt title * Update to new strucutre * Remove test * Update READMEs Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
185
examples/research_projects/movement-pruning/README.md
Normal file
185
examples/research_projects/movement-pruning/README.md
Normal file
@@ -0,0 +1,185 @@
|
||||
# Movement Pruning: Adaptive Sparsity by Fine-Tuning
|
||||
|
||||
Author: @VictorSanh
|
||||
|
||||
*Magnitude pruning is a widely used strategy for reducing model size in pure supervised learning; however, it is less effective in the transfer learning regime that has become standard for state-of-the-art natural language processing applications. We propose the use of *movement pruning*, a simple, deterministic first-order weight pruning method that is more adaptive to pretrained model fine-tuning. Experiments show that when pruning large pretrained language models, movement pruning shows significant improvements in high-sparsity regimes. When combined with distillation, the approach achieves minimal accuracy loss with down to only 3% of the model parameters:*
|
||||
|
||||
| Fine-pruning+Distillation<br>(Teacher=BERT-base fine-tuned) | BERT base<br>fine-tuned | Remaining<br>Weights (%) | Magnitude Pruning | L0 Regularization | Movement Pruning | Soft Movement Pruning |
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
||||
| SQuAD - Dev<br>EM/F1 | 80.4/88.1 | 10%<br>3% | 70.2/80.1<br>45.5/59.6 | 72.4/81.9<br>64.3/75.8 | 75.6/84.3<br>67.5/78.0 | **76.6/84.9**<br>**72.7/82.3** |
|
||||
| MNLI - Dev<br>acc/MM acc | 84.5/84.9 | 10%<br>3% | 78.3/79.3<br>69.4/70.6 | 78.7/79.7<br>76.0/76.2 | 80.1/80.4<br>76.5/77.4 | **81.2/81.8**<br>**79.5/80.1** |
|
||||
| QQP - Dev<br>acc/F1 | 91.4/88.4 | 10%<br>3% | 79.8/65.0<br>72.4/57.8 | 88.1/82.8<br>87.0/81.9 | 89.7/86.2<br>86.1/81.5 | **90.2/86.8**<br>**89.1/85.5** |
|
||||
|
||||
This page contains information on how to fine-prune pre-trained models such as `BERT` to obtain extremely sparse models with movement pruning. In contrast to magnitude pruning which selects weights that are far from 0, movement pruning retains weights that are moving away from 0.
|
||||
|
||||
For more information, we invite you to check out [our paper](https://arxiv.org/abs/2005.07683).
|
||||
You can also have a look at this fun *Explain Like I'm Five* introductory [slide deck](https://www.slideshare.net/VictorSanh/movement-pruning-explain-like-im-five-234205241).
|
||||
|
||||
<div align="center">
|
||||
<img src="https://www.seekpng.com/png/detail/166-1669328_how-to-make-emmental-cheese-at-home-icooker.png" width="400">
|
||||
</div>
|
||||
|
||||
## Extreme sparsity and efficient storage
|
||||
|
||||
One promise of extreme pruning is to obtain extremely small models that can be easily sent (and stored) on edge devices. By setting weights to 0., we reduce the amount of information we need to store, and thus decreasing the memory size. We are able to obtain extremely sparse fine-pruned models with movement pruning: ~95% of the dense performance with ~5% of total remaining weights in the BERT encoder.
|
||||
|
||||
In [this notebook](https://github.com/huggingface/transformers/blob/master/examples/movement-pruning/Saving_PruneBERT.ipynb), we showcase how we can leverage standard tools that exist out-of-the-box to efficiently store an extremely sparse question answering model (only 6% of total remaining weights in the encoder). We are able to reduce the memory size of the encoder **from the 340MB (the original dense BERT) to 11MB**, without any additional training of the model (every operation is performed *post fine-pruning*). It is sufficiently small to store it on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical) 📎!
|
||||
|
||||
While movement pruning does not directly optimize for memory footprint (but rather the number of non-null weights), we hypothetize that further memory compression ratios can be achieved with specific quantization aware trainings (see for instance [Q8BERT](https://arxiv.org/abs/1910.06188), [And the Bit Goes Down](https://arxiv.org/abs/1907.05686) or [Quant-Noise](https://arxiv.org/abs/2004.07320)).
|
||||
|
||||
## Fine-pruned models
|
||||
|
||||
As examples, we release two English PruneBERT checkpoints (models fine-pruned from a pre-trained `BERT` checkpoint), one on SQuAD and the other on MNLI.
|
||||
|
||||
- **`prunebert-base-uncased-6-finepruned-w-distil-squad`**<br/>
|
||||
Pre-trained `BERT-base-uncased` fine-pruned with soft movement pruning on SQuAD v1.1. We use an additional distillation signal from `BERT-base-uncased` finetuned on SQuAD. The encoder counts 6% of total non-null weights and reaches 83.8 F1 score. The model can be accessed with: `pruned_bert = BertForQuestionAnswering.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad")`
|
||||
- **`prunebert-base-uncased-6-finepruned-w-distil-mnli`**<br/>
|
||||
Pre-trained `BERT-base-uncased` fine-pruned with soft movement pruning on MNLI. We use an additional distillation signal from `BERT-base-uncased` finetuned on MNLI. The encoder counts 6% of total non-null weights and reaches 80.7 (matched) accuracy. The model can be accessed with: `pruned_bert = BertForSequenceClassification.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-mnli")`
|
||||
|
||||
## How to fine-prune?
|
||||
|
||||
### Setup
|
||||
|
||||
The code relies on the 🤗 Transformers library. In addition to the dependencies listed in the [`examples`](https://github.com/huggingface/transformers/tree/master/examples) folder, you should install a few additional dependencies listed in the `requirements.txt` file: `pip install -r requirements.txt`.
|
||||
|
||||
Note that we built our experiments on top of a stabilized version of the library (commit https://github.com/huggingface/transformers/commit/352d5472b0c1dec0f420d606d16747d851b4bda8): we do not guarantee that everything is still compatible with the latest version of the master branch.
|
||||
|
||||
### Fine-pruning with movement pruning
|
||||
|
||||
Below, we detail how to reproduce the results reported in the paper. We use SQuAD as a running example. Commands (and scripts) can be easily adapted for other tasks.
|
||||
|
||||
The following command fine-prunes a pre-trained `BERT-base` on SQuAD using movement pruning towards 15% of remaining weights (85% sparsity). Note that we freeze all the embeddings modules (from their pre-trained value) and only prune the Fully Connected layers in the encoder (12 layers of Transformer Block).
|
||||
|
||||
```bash
|
||||
SERIALIZATION_DIR=<OUTPUT_DIR>
|
||||
SQUAD_DATA=<SQUAD_DATA>
|
||||
|
||||
python examples/movement-pruning/masked_run_squad.py \
|
||||
--output_dir $SERIALIZATION_DIR \
|
||||
--data_dir $SQUAD_DATA \
|
||||
--train_file train-v1.1.json \
|
||||
--predict_file dev-v1.1.json \
|
||||
--do_train --do_eval --do_lower_case \
|
||||
--model_type masked_bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--per_gpu_train_batch_size 16 \
|
||||
--warmup_steps 5400 \
|
||||
--num_train_epochs 10 \
|
||||
--learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
|
||||
--initial_threshold 1 --final_threshold 0.15 \
|
||||
--initial_warmup 1 --final_warmup 2 \
|
||||
--pruning_method topK --mask_init constant --mask_scale 0.
|
||||
```
|
||||
|
||||
### Fine-pruning with other methods
|
||||
|
||||
We can also explore other fine-pruning methods by changing the `pruning_method` parameter:
|
||||
|
||||
Soft movement pruning
|
||||
```bash
|
||||
python examples/movement-pruning/masked_run_squad.py \
|
||||
--output_dir $SERIALIZATION_DIR \
|
||||
--data_dir $SQUAD_DATA \
|
||||
--train_file train-v1.1.json \
|
||||
--predict_file dev-v1.1.json \
|
||||
--do_train --do_eval --do_lower_case \
|
||||
--model_type masked_bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--per_gpu_train_batch_size 16 \
|
||||
--warmup_steps 5400 \
|
||||
--num_train_epochs 10 \
|
||||
--learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
|
||||
--initial_threshold 0 --final_threshold 0.1 \
|
||||
--initial_warmup 1 --final_warmup 2 \
|
||||
--pruning_method sigmoied_threshold --mask_init constant --mask_scale 0. \
|
||||
--regularization l1 --final_lambda 400.
|
||||
```
|
||||
|
||||
L0 regularization
|
||||
```bash
|
||||
python examples/movement-pruning/masked_run_squad.py \
|
||||
--output_dir $SERIALIZATION_DIR \
|
||||
--data_dir $SQUAD_DATA \
|
||||
--train_file train-v1.1.json \
|
||||
--predict_file dev-v1.1.json \
|
||||
--do_train --do_eval --do_lower_case \
|
||||
--model_type masked_bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--per_gpu_train_batch_size 16 \
|
||||
--warmup_steps 5400 \
|
||||
--num_train_epochs 10 \
|
||||
--learning_rate 3e-5 --mask_scores_learning_rate 1e-1 \
|
||||
--initial_threshold 1. --final_threshold 1. \
|
||||
--initial_warmup 1 --final_warmup 1 \
|
||||
--pruning_method l0 --mask_init constant --mask_scale 2.197 \
|
||||
--regularization l0 --final_lambda 125.
|
||||
```
|
||||
|
||||
Iterative Magnitude Pruning
|
||||
```bash
|
||||
python examples/movement-pruning/masked_run_squad.py \
|
||||
--output_dir ./dbg \
|
||||
--data_dir examples/distillation/data/squad_data \
|
||||
--train_file train-v1.1.json \
|
||||
--predict_file dev-v1.1.json \
|
||||
--do_train --do_eval --do_lower_case \
|
||||
--model_type masked_bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--per_gpu_train_batch_size 16 \
|
||||
--warmup_steps 5400 \
|
||||
--num_train_epochs 10 \
|
||||
--learning_rate 3e-5 \
|
||||
--initial_threshold 1 --final_threshold 0.15 \
|
||||
--initial_warmup 1 --final_warmup 2 \
|
||||
--pruning_method magnitude
|
||||
```
|
||||
|
||||
### After fine-pruning
|
||||
|
||||
**Counting parameters**
|
||||
|
||||
Regularization based pruning methods (soft movement pruning and L0 regularization) rely on the penalty to induce sparsity. The multiplicative coefficient controls the sparsity level.
|
||||
To obtain the effective sparsity level in the encoder, we simply count the number of activated (non-null) weights:
|
||||
|
||||
```bash
|
||||
python examples/movement-pruning/counts_parameters.py \
|
||||
--pruning_method sigmoied_threshold \
|
||||
--threshold 0.1 \
|
||||
--serialization_dir $SERIALIZATION_DIR
|
||||
```
|
||||
|
||||
**Pruning once for all**
|
||||
|
||||
Once the model has been fine-pruned, the pruned weights can be set to 0. once for all (reducing the amount of information to store). In our running experiments, we can convert a `MaskedBertForQuestionAnswering` (a BERT model augmented to enable on-the-fly pruning capabilities) to a standard `BertForQuestionAnswering`:
|
||||
|
||||
```bash
|
||||
python examples/movement-pruning/bertarize.py \
|
||||
--pruning_method sigmoied_threshold \
|
||||
--threshold 0.1 \
|
||||
--model_name_or_path $SERIALIZATION_DIR
|
||||
```
|
||||
|
||||
## Hyper-parameters
|
||||
|
||||
For reproducibility purposes, we share the detailed results presented in the paper. These [tables](https://docs.google.com/spreadsheets/d/17JgRq_OFFTniUrz6BZWW_87DjFkKXpI1kYDSsseT_7g/edit?usp=sharing) exhaustively describe the individual hyper-parameters used for each data point.
|
||||
|
||||
## Inference speed
|
||||
|
||||
Early experiments show that even though models fine-pruned with (soft) movement pruning are extremely sparse, they do not benefit from significant improvement in terms of inference speed when using the standard PyTorch inference.
|
||||
We are currently benchmarking and exploring inference setups specifically for sparse architectures.
|
||||
In particular, hardware manufacturers are announcing devices that will speedup inference for sparse networks considerably.
|
||||
|
||||
## Citation
|
||||
|
||||
If you find this resource useful, please consider citing the following paper:
|
||||
|
||||
```
|
||||
@article{sanh2020movement,
|
||||
title={Movement Pruning: Adaptive Sparsity by Fine-Tuning},
|
||||
author={Victor Sanh and Thomas Wolf and Alexander M. Rush},
|
||||
year={2020},
|
||||
eprint={2005.07683},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,634 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Saving PruneBERT\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"This notebook aims at showcasing how we can leverage standard tools to save (and load) an extremely sparse model fine-pruned with [movement pruning](https://arxiv.org/abs/2005.07683) (or any other unstructured pruning mehtod).\n",
|
||||
"\n",
|
||||
"In this example, we used BERT (base-uncased, but the procedure described here is not specific to BERT and can be applied to a large variety of models.\n",
|
||||
"\n",
|
||||
"We first obtain an extremely sparse model by fine-pruning with movement pruning on SQuAD v1.1. We then used the following combination of standard tools:\n",
|
||||
"- We reduce the precision of the model with Int8 dynamic quantization using [PyTorch implementation](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html). We only quantized the Fully Connected Layers.\n",
|
||||
"- Sparse quantized matrices are converted into the [Compressed Sparse Row format](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html).\n",
|
||||
"- We use HDF5 with `gzip` compression to store the weights.\n",
|
||||
"\n",
|
||||
"We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical)!\n",
|
||||
"\n",
|
||||
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Floptical_disk_21MB.jpg/440px-Floptical_disk_21MB.jpg\" width=\"200\">\n",
|
||||
"\n",
|
||||
"*Note: this notebook is compatible with `torch>=1.5.0` If you are using, `torch==1.4.0`, please refer to [this previous version of the notebook](https://github.com/huggingface/transformers/commit/b11386e158e86e62d4041eabd86d044cd1695737).*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Includes\n",
|
||||
"\n",
|
||||
"import h5py\n",
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"\n",
|
||||
"from scipy import sparse\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"\n",
|
||||
"from transformers import *\n",
|
||||
"\n",
|
||||
"os.chdir('../../')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Saving"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Dynamic quantization induces little or no loss of performance while significantly reducing the memory footprint."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load fine-pruned model and quantize the model\n",
|
||||
"\n",
|
||||
"model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
|
||||
"model.to('cpu')\n",
|
||||
"\n",
|
||||
"quantized_model = torch.quantization.quantize_dynamic(\n",
|
||||
" model=model,\n",
|
||||
" qconfig_spec = {\n",
|
||||
" torch.nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
|
||||
" },\n",
|
||||
" dtype=torch.qint8,\n",
|
||||
" )\n",
|
||||
"# print(quantized_model)\n",
|
||||
"\n",
|
||||
"qtz_st = quantized_model.state_dict()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Saving the original (encoder + classifier) in the standard torch.save format\n",
|
||||
"\n",
|
||||
"dense_st = {name: param for name, param in model.state_dict().items() \n",
|
||||
" if \"embedding\" not in name and \"pooler\" not in name}\n",
|
||||
"torch.save(dense_st, 'dbg/dense_squad.pt',)\n",
|
||||
"dense_mb_size = os.path.getsize(\"dbg/dense_squad.pt\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Decompose quantization for bert.encoder.layer.0.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.0.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.1.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.2.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.3.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.4.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.5.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.6.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.7.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.8.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.9.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.10.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.attention.self.query._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.attention.self.key._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.attention.self.value._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.attention.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.intermediate.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.encoder.layer.11.output.dense._packed_params.weight\n",
|
||||
"Decompose quantization for bert.pooler.dense._packed_params.weight\n",
|
||||
"Decompose quantization for qa_outputs._packed_params.weight\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Elementary representation: we decompose the quantized tensors into (scale, zero_point, int_repr).\n",
|
||||
"# See https://pytorch.org/docs/stable/quantization.html\n",
|
||||
"\n",
|
||||
"# We further leverage the fact that int_repr is sparse matrix to optimize the storage: we decompose int_repr into\n",
|
||||
"# its CSR representation (data, indptr, indices).\n",
|
||||
"\n",
|
||||
"elementary_qtz_st = {}\n",
|
||||
"for name, param in qtz_st.items():\n",
|
||||
" if \"dtype\" not in name and param.is_quantized:\n",
|
||||
" print(\"Decompose quantization for\", name)\n",
|
||||
" # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
|
||||
" scale = param.q_scale() # torch.tensor(1,) - float32\n",
|
||||
" zero_point = param.q_zero_point() # torch.tensor(1,) - int32\n",
|
||||
" elementary_qtz_st[f\"{name}.scale\"] = scale\n",
|
||||
" elementary_qtz_st[f\"{name}.zero_point\"] = zero_point\n",
|
||||
"\n",
|
||||
" # We assume the int_repr is sparse and compute its CSR representation\n",
|
||||
" # Only the FCs in the encoder are actually sparse\n",
|
||||
" int_repr = param.int_repr() # torch.tensor(nb_rows, nb_columns) - int8\n",
|
||||
" int_repr_cs = sparse.csr_matrix(int_repr) # scipy.sparse.csr.csr_matrix\n",
|
||||
"\n",
|
||||
" elementary_qtz_st[f\"{name}.int_repr.data\"] = int_repr_cs.data # np.array int8\n",
|
||||
" elementary_qtz_st[f\"{name}.int_repr.indptr\"] = int_repr_cs.indptr # np.array int32\n",
|
||||
" assert max(int_repr_cs.indices) < 65535 # If not, we shall fall back to int32\n",
|
||||
" elementary_qtz_st[f\"{name}.int_repr.indices\"] = np.uint16(int_repr_cs.indices) # np.array uint16\n",
|
||||
" elementary_qtz_st[f\"{name}.int_repr.shape\"] = int_repr_cs.shape # tuple(int, int)\n",
|
||||
" else:\n",
|
||||
" elementary_qtz_st[name] = param\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create mapping from torch.dtype to string description (we could also used an int8 instead of string)\n",
|
||||
"str_2_dtype = {\"qint8\": torch.qint8}\n",
|
||||
"dtype_2_str = {torch.qint8: \"qint8\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Encoder Size (MB) - Sparse & Quantized - `torch.save`: 21.29\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Saving the pruned (encoder + classifier) in the standard torch.save format\n",
|
||||
"\n",
|
||||
"dense_optimized_st = {name: param for name, param in elementary_qtz_st.items() \n",
|
||||
" if \"embedding\" not in name and \"pooler\" not in name}\n",
|
||||
"torch.save(dense_optimized_st, 'dbg/dense_squad_optimized.pt',)\n",
|
||||
"print(\"Encoder Size (MB) - Sparse & Quantized - `torch.save`:\",\n",
|
||||
" round(os.path.getsize(\"dbg/dense_squad_optimized.pt\")/1e6, 2))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Skip bert.embeddings.word_embeddings.weight\n",
|
||||
"Skip bert.embeddings.position_embeddings.weight\n",
|
||||
"Skip bert.embeddings.token_type_embeddings.weight\n",
|
||||
"Skip bert.embeddings.LayerNorm.weight\n",
|
||||
"Skip bert.embeddings.LayerNorm.bias\n",
|
||||
"Skip bert.pooler.dense.scale\n",
|
||||
"Skip bert.pooler.dense.zero_point\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.scale\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.zero_point\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.int_repr.data\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.int_repr.indptr\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n",
|
||||
"Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n",
|
||||
"Skip bert.pooler.dense._packed_params.bias\n",
|
||||
"Skip bert.pooler.dense._packed_params.dtype\n",
|
||||
"\n",
|
||||
"Encoder Size (MB) - Dense: 340.26\n",
|
||||
"Encoder Size (MB) - Sparse & Quantized: 11.28\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Save the decomposed state_dict with an HDF5 file\n",
|
||||
"# Saving only the encoder + QA Head\n",
|
||||
"\n",
|
||||
"with h5py.File('dbg/squad_sparse.h5','w') as hf:\n",
|
||||
" for name, param in elementary_qtz_st.items():\n",
|
||||
" if \"embedding\" in name:\n",
|
||||
" print(f\"Skip {name}\")\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" if \"pooler\" in name:\n",
|
||||
" print(f\"Skip {name}\")\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" if type(param) == torch.Tensor:\n",
|
||||
" if param.numel() == 1:\n",
|
||||
" # module scale\n",
|
||||
" # module zero_point\n",
|
||||
" hf.attrs[name] = param\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" if param.requires_grad:\n",
|
||||
" # LayerNorm\n",
|
||||
" param = param.detach().numpy()\n",
|
||||
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
|
||||
"\n",
|
||||
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
|
||||
" # float - tensor _packed_params.weight.scale\n",
|
||||
" # int - tensor _packed_params.weight.zero_point\n",
|
||||
" # tuple - tensor _packed_params.weight.shape\n",
|
||||
" hf.attrs[name] = param\n",
|
||||
"\n",
|
||||
" elif type(param) == torch.dtype:\n",
|
||||
" # dtype - tensor _packed_params.dtype\n",
|
||||
" hf.attrs[name] = dtype_2_str[param]\n",
|
||||
" \n",
|
||||
" else:\n",
|
||||
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"with open('dbg/metadata.json', 'w') as f:\n",
|
||||
" f.write(json.dumps(qtz_st._metadata)) \n",
|
||||
"\n",
|
||||
"size = os.path.getsize(\"dbg/squad_sparse.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
|
||||
"print(\"\")\n",
|
||||
"print(\"Encoder Size (MB) - Dense: \", round(dense_mb_size/1e6, 2))\n",
|
||||
"print(\"Encoder Size (MB) - Sparse & Quantized:\", round(size/1e6, 2))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Size (MB): 99.41\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Save the decomposed state_dict to HDF5 storage\n",
|
||||
"# Save everything in the architecutre (embedding + encoder + QA Head)\n",
|
||||
"\n",
|
||||
"with h5py.File('dbg/squad_sparse_with_embs.h5','w') as hf:\n",
|
||||
" for name, param in elementary_qtz_st.items():\n",
|
||||
"# if \"embedding\" in name:\n",
|
||||
"# print(f\"Skip {name}\")\n",
|
||||
"# continue\n",
|
||||
"\n",
|
||||
"# if \"pooler\" in name:\n",
|
||||
"# print(f\"Skip {name}\")\n",
|
||||
"# continue\n",
|
||||
"\n",
|
||||
" if type(param) == torch.Tensor:\n",
|
||||
" if param.numel() == 1:\n",
|
||||
" # module scale\n",
|
||||
" # module zero_point\n",
|
||||
" hf.attrs[name] = param\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" if param.requires_grad:\n",
|
||||
" # LayerNorm\n",
|
||||
" param = param.detach().numpy()\n",
|
||||
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
|
||||
"\n",
|
||||
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
|
||||
" # float - tensor _packed_params.weight.scale\n",
|
||||
" # int - tensor _packed_params.weight.zero_point\n",
|
||||
" # tuple - tensor _packed_params.weight.shape\n",
|
||||
" hf.attrs[name] = param\n",
|
||||
"\n",
|
||||
" elif type(param) == torch.dtype:\n",
|
||||
" # dtype - tensor _packed_params.dtype\n",
|
||||
" hf.attrs[name] = dtype_2_str[param]\n",
|
||||
" \n",
|
||||
" else:\n",
|
||||
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"with open('dbg/metadata.json', 'w') as f:\n",
|
||||
" f.write(json.dumps(qtz_st._metadata)) \n",
|
||||
"\n",
|
||||
"size = os.path.getsize(\"dbg/squad_sparse_with_embs.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
|
||||
"print('\\nSize (MB):', round(size/1e6, 2))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Loading"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Reconstruct the elementary state dict\n",
|
||||
"\n",
|
||||
"reconstructed_elementary_qtz_st = {}\n",
|
||||
"\n",
|
||||
"hf = h5py.File('dbg/squad_sparse_with_embs.h5','r')\n",
|
||||
"\n",
|
||||
"for attr_name, attr_param in hf.attrs.items():\n",
|
||||
" if 'shape' in attr_name:\n",
|
||||
" attr_param = tuple(attr_param)\n",
|
||||
" elif \".scale\" in attr_name:\n",
|
||||
" if \"_packed_params\" in attr_name:\n",
|
||||
" attr_param = float(attr_param)\n",
|
||||
" else:\n",
|
||||
" attr_param = torch.tensor(attr_param)\n",
|
||||
" elif \".zero_point\" in attr_name:\n",
|
||||
" if \"_packed_params\" in attr_name:\n",
|
||||
" attr_param = int(attr_param)\n",
|
||||
" else:\n",
|
||||
" attr_param = torch.tensor(attr_param)\n",
|
||||
" elif \".dtype\" in attr_name:\n",
|
||||
" attr_param = str_2_dtype[attr_param]\n",
|
||||
" reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
|
||||
" # print(f\"Unpack {attr_name}\")\n",
|
||||
" \n",
|
||||
"# Get the tensors/arrays\n",
|
||||
"for data_name, data_param in hf.items():\n",
|
||||
" if \"LayerNorm\" in data_name or \"_packed_params.bias\" in data_name:\n",
|
||||
" reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
|
||||
" elif \"embedding\" in data_name:\n",
|
||||
" reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
|
||||
" else: # _packed_params.weight.int_repr.data, _packed_params.weight.int_repr.indices and _packed_params.weight.int_repr.indptr\n",
|
||||
" data_param = np.array(data_param)\n",
|
||||
" if \"indices\" in data_name:\n",
|
||||
" data_param = np.array(data_param, dtype=np.int32)\n",
|
||||
" reconstructed_elementary_qtz_st[data_name] = data_param\n",
|
||||
" # print(f\"Unpack {data_name}\")\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"hf.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Sanity checks\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_elementary_qtz_st.items():\n",
|
||||
" assert name in elementary_qtz_st\n",
|
||||
"for name, param in elementary_qtz_st.items():\n",
|
||||
" assert name in reconstructed_elementary_qtz_st, name\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_elementary_qtz_st.items():\n",
|
||||
" assert type(param) == type(elementary_qtz_st[name]), name\n",
|
||||
" if type(param) == torch.Tensor:\n",
|
||||
" assert torch.all(torch.eq(param, elementary_qtz_st[name])), name\n",
|
||||
" elif type(param) == np.ndarray:\n",
|
||||
" assert (param == elementary_qtz_st[name]).all(), name\n",
|
||||
" else:\n",
|
||||
" assert param == elementary_qtz_st[name], name"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Re-assemble the sparse int_repr from the CSR format\n",
|
||||
"\n",
|
||||
"reconstructed_qtz_st = {}\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_elementary_qtz_st.items():\n",
|
||||
" if \"weight.int_repr.indptr\" in name:\n",
|
||||
" prefix_ = name[:-16]\n",
|
||||
" data = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.data\"]\n",
|
||||
" indptr = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indptr\"]\n",
|
||||
" indices = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indices\"]\n",
|
||||
" shape = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.shape\"]\n",
|
||||
"\n",
|
||||
" int_repr = sparse.csr_matrix(arg1=(data, indices, indptr),\n",
|
||||
" shape=shape)\n",
|
||||
" int_repr = torch.tensor(int_repr.todense())\n",
|
||||
"\n",
|
||||
" scale = reconstructed_elementary_qtz_st[f\"{prefix_}.scale\"]\n",
|
||||
" zero_point = reconstructed_elementary_qtz_st[f\"{prefix_}.zero_point\"]\n",
|
||||
" weight = torch._make_per_tensor_quantized_tensor(int_repr,\n",
|
||||
" scale,\n",
|
||||
" zero_point)\n",
|
||||
"\n",
|
||||
" reconstructed_qtz_st[f\"{prefix_}\"] = weight\n",
|
||||
" elif \"int_repr.data\" in name or \"int_repr.shape\" in name or \"int_repr.indices\" in name or \\\n",
|
||||
" \"weight.scale\" in name or \"weight.zero_point\" in name:\n",
|
||||
" continue\n",
|
||||
" else:\n",
|
||||
" reconstructed_qtz_st[name] = param\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Sanity checks\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_qtz_st.items():\n",
|
||||
" assert name in qtz_st\n",
|
||||
"for name, param in qtz_st.items():\n",
|
||||
" assert name in reconstructed_qtz_st, name\n",
|
||||
"\n",
|
||||
"for name, param in reconstructed_qtz_st.items():\n",
|
||||
" assert type(param) == type(qtz_st[name]), name\n",
|
||||
" if type(param) == torch.Tensor:\n",
|
||||
" assert torch.all(torch.eq(param, qtz_st[name])), name\n",
|
||||
" elif type(param) == np.ndarray:\n",
|
||||
" assert (param == qtz_st[name]).all(), name\n",
|
||||
" else:\n",
|
||||
" assert param == qtz_st[name], name"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Sanity checks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<All keys matched successfully>"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Load the re-constructed state dict into a model\n",
|
||||
"\n",
|
||||
"dummy_model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')\n",
|
||||
"dummy_model.to('cpu')\n",
|
||||
"\n",
|
||||
"reconstructed_qtz_model = torch.quantization.quantize_dynamic(\n",
|
||||
" model=dummy_model,\n",
|
||||
" qconfig_spec = None,\n",
|
||||
" dtype=torch.qint8,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"reconstructed_qtz_st = OrderedDict(reconstructed_qtz_st)\n",
|
||||
"with open('dbg/metadata.json', 'r') as read_file:\n",
|
||||
" metadata = json.loads(read_file.read())\n",
|
||||
"reconstructed_qtz_st._metadata = metadata\n",
|
||||
"\n",
|
||||
"reconstructed_qtz_model.load_state_dict(reconstructed_qtz_st)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sanity check passed\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Sanity checks on the infernce\n",
|
||||
"\n",
|
||||
"N = 32\n",
|
||||
"\n",
|
||||
"for _ in range(25):\n",
|
||||
" inputs = torch.randint(low=0, high=30000, size=(N, 128))\n",
|
||||
" mask = torch.ones(size=(N, 128))\n",
|
||||
"\n",
|
||||
" y_reconstructed = reconstructed_qtz_model(input_ids=inputs, attention_mask=mask)[0]\n",
|
||||
" y = quantized_model(input_ids=inputs, attention_mask=mask)[0]\n",
|
||||
" \n",
|
||||
" assert torch.all(torch.eq(y, y_reconstructed))\n",
|
||||
"print(\"Sanity check passed\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
132
examples/research_projects/movement-pruning/bertarize.py
Normal file
132
examples/research_projects/movement-pruning/bertarize.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright 2020-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once for all.
|
||||
For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceClassification` is trained, it can be saved (and then loaded)
|
||||
as a standard :class:`~transformers.BertForSequenceClassification`.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
def main(args):
|
||||
pruning_method = args.pruning_method
|
||||
threshold = args.threshold
|
||||
|
||||
model_name_or_path = args.model_name_or_path.rstrip("/")
|
||||
target_model_path = args.target_model_path
|
||||
|
||||
print(f"Load fine-pruned model from {model_name_or_path}")
|
||||
model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
|
||||
pruned_model = {}
|
||||
|
||||
for name, tensor in model.items():
|
||||
if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
|
||||
pruned_model[name] = tensor
|
||||
print(f"Copied layer {name}")
|
||||
elif "classifier" in name or "qa_output" in name:
|
||||
pruned_model[name] = tensor
|
||||
print(f"Copied layer {name}")
|
||||
elif "bias" in name:
|
||||
pruned_model[name] = tensor
|
||||
print(f"Copied layer {name}")
|
||||
else:
|
||||
if pruning_method == "magnitude":
|
||||
mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
|
||||
pruned_model[name] = tensor * mask
|
||||
print(f"Pruned layer {name}")
|
||||
elif pruning_method == "topK":
|
||||
if "mask_scores" in name:
|
||||
continue
|
||||
prefix_ = name[:-6]
|
||||
scores = model[f"{prefix_}mask_scores"]
|
||||
mask = TopKBinarizer.apply(scores, threshold)
|
||||
pruned_model[name] = tensor * mask
|
||||
print(f"Pruned layer {name}")
|
||||
elif pruning_method == "sigmoied_threshold":
|
||||
if "mask_scores" in name:
|
||||
continue
|
||||
prefix_ = name[:-6]
|
||||
scores = model[f"{prefix_}mask_scores"]
|
||||
mask = ThresholdBinarizer.apply(scores, threshold, True)
|
||||
pruned_model[name] = tensor * mask
|
||||
print(f"Pruned layer {name}")
|
||||
elif pruning_method == "l0":
|
||||
if "mask_scores" in name:
|
||||
continue
|
||||
prefix_ = name[:-6]
|
||||
scores = model[f"{prefix_}mask_scores"]
|
||||
l, r = -0.1, 1.1
|
||||
s = torch.sigmoid(scores)
|
||||
s_bar = s * (r - l) + l
|
||||
mask = s_bar.clamp(min=0.0, max=1.0)
|
||||
pruned_model[name] = tensor * mask
|
||||
print(f"Pruned layer {name}")
|
||||
else:
|
||||
raise ValueError("Unknown pruning method")
|
||||
|
||||
if target_model_path is None:
|
||||
target_model_path = os.path.join(
|
||||
os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}"
|
||||
)
|
||||
|
||||
if not os.path.isdir(target_model_path):
|
||||
shutil.copytree(model_name_or_path, target_model_path)
|
||||
print(f"\nCreated folder {target_model_path}")
|
||||
|
||||
torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
|
||||
print("\nPruned model saved! See you later!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--pruning_method",
|
||||
choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
|
||||
type=str,
|
||||
required=True,
|
||||
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
required=False,
|
||||
help="For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
|
||||
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
|
||||
"Not needed for `l0`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Folder containing the model that was previously fine-pruned",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_model_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Folder containing the model that was previously fine-pruned",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -0,0 +1,92 @@
|
||||
# Copyright 2020-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Count remaining (non-zero) weights in the encoder (i.e. the transformer layers).
|
||||
Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from emmental.modules import ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
def main(args):
|
||||
serialization_dir = args.serialization_dir
|
||||
pruning_method = args.pruning_method
|
||||
threshold = args.threshold
|
||||
|
||||
st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cpu")
|
||||
|
||||
remaining_count = 0 # Number of remaining (not pruned) params in the encoder
|
||||
encoder_count = 0 # Number of params in the encoder
|
||||
|
||||
print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
|
||||
for name, param in st.items():
|
||||
if "encoder" not in name:
|
||||
continue
|
||||
|
||||
if "mask_scores" in name:
|
||||
if pruning_method == "topK":
|
||||
mask_ones = TopKBinarizer.apply(param, threshold).sum().item()
|
||||
elif pruning_method == "sigmoied_threshold":
|
||||
mask_ones = ThresholdBinarizer.apply(param, threshold, True).sum().item()
|
||||
elif pruning_method == "l0":
|
||||
l, r = -0.1, 1.1
|
||||
s = torch.sigmoid(param)
|
||||
s_bar = s * (r - l) + l
|
||||
mask = s_bar.clamp(min=0.0, max=1.0)
|
||||
mask_ones = (mask > 0.0).sum().item()
|
||||
else:
|
||||
raise ValueError("Unknown pruning method")
|
||||
remaining_count += mask_ones
|
||||
print(name.ljust(60, " "), str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "), str(mask_ones))
|
||||
else:
|
||||
encoder_count += param.numel()
|
||||
if "bias" in name or "LayerNorm" in name:
|
||||
remaining_count += param.numel()
|
||||
|
||||
print("")
|
||||
print("Remaining Weights (global) %: ", 100 * remaining_count / encoder_count)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--pruning_method",
|
||||
choices=["l0", "topK", "sigmoied_threshold"],
|
||||
type=str,
|
||||
required=True,
|
||||
help="Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
required=False,
|
||||
help="For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
|
||||
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
|
||||
"Not needed for `l0`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--serialization_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Folder containing the model that was previously fine-pruned",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -0,0 +1,10 @@
|
||||
# flake8: noqa
|
||||
from .configuration_bert_masked import MaskedBertConfig
|
||||
from .modeling_bert_masked import (
|
||||
MaskedBertForMultipleChoice,
|
||||
MaskedBertForQuestionAnswering,
|
||||
MaskedBertForSequenceClassification,
|
||||
MaskedBertForTokenClassification,
|
||||
MaskedBertModel,
|
||||
)
|
||||
from .modules import *
|
||||
@@ -0,0 +1,71 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Masked BERT model configuration. It replicates the class `~transformers.BertConfig`
|
||||
and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init` and `mask_scale`."""
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MaskedBertConfig(PretrainedConfig):
|
||||
"""
|
||||
A class replicating the `~transformers.BertConfig` with additional parameters for pruning/masking configuration.
|
||||
"""
|
||||
|
||||
model_type = "masked_bert"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
pruning_method="topK",
|
||||
mask_init="constant",
|
||||
mask_scale=0.0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.pruning_method = pruning_method
|
||||
self.mask_init = mask_init
|
||||
self.mask_scale = mask_scale
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,3 @@
|
||||
# flake8: noqa
|
||||
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
from .masked_nn import MaskedLinear
|
||||
@@ -0,0 +1,144 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present, AllenAI Authors, University of Illinois Urbana-Champaign,
|
||||
# Intel Nervana Systems and the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Binarizers take a (real value) matrix as input and produce a binary (values in {0,1}) mask of the same shape.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import autograd
|
||||
|
||||
|
||||
class ThresholdBinarizer(autograd.Function):
|
||||
"""
|
||||
Thresholdd binarizer.
|
||||
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j} > \tau`
|
||||
where `\tau` is a real value threshold.
|
||||
|
||||
Implementation is inspired from:
|
||||
https://github.com/arunmallya/piggyback
|
||||
Piggyback: Adapting a Single Network to Multiple Tasks by Learning to Mask Weights
|
||||
Arun Mallya, Dillon Davis, Svetlana Lazebnik
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs: torch.tensor, threshold: float, sigmoid: bool):
|
||||
"""
|
||||
Args:
|
||||
inputs (`torch.FloatTensor`)
|
||||
The input matrix from which the binarizer computes the binary mask.
|
||||
threshold (`float`)
|
||||
The threshold value (in R).
|
||||
sigmoid (`bool`)
|
||||
If set to ``True``, we apply the sigmoid function to the `inputs` matrix before comparing to `threshold`.
|
||||
In this case, `threshold` should be a value between 0 and 1.
|
||||
Returns:
|
||||
mask (`torch.FloatTensor`)
|
||||
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
|
||||
retained, 0 - the associated weight is pruned).
|
||||
"""
|
||||
nb_elems = inputs.numel()
|
||||
nb_min = int(0.005 * nb_elems) + 1
|
||||
if sigmoid:
|
||||
mask = (torch.sigmoid(inputs) > threshold).type(inputs.type())
|
||||
else:
|
||||
mask = (inputs > threshold).type(inputs.type())
|
||||
if mask.sum() < nb_min:
|
||||
# We limit the pruning so that at least 0.5% (half a percent) of the weights are remaining
|
||||
k_threshold = inputs.flatten().kthvalue(max(nb_elems - nb_min, 1)).values
|
||||
mask = (inputs > k_threshold).type(inputs.type())
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gradOutput):
|
||||
return gradOutput, None, None
|
||||
|
||||
|
||||
class TopKBinarizer(autograd.Function):
|
||||
"""
|
||||
Top-k Binarizer.
|
||||
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
|
||||
is among the k% highest values of S.
|
||||
|
||||
Implementation is inspired from:
|
||||
https://github.com/allenai/hidden-networks
|
||||
What's hidden in a randomly weighted neural network?
|
||||
Vivek Ramanujan*, Mitchell Wortsman*, Aniruddha Kembhavi, Ali Farhadi, Mohammad Rastegari
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs: torch.tensor, threshold: float):
|
||||
"""
|
||||
Args:
|
||||
inputs (`torch.FloatTensor`)
|
||||
The input matrix from which the binarizer computes the binary mask.
|
||||
threshold (`float`)
|
||||
The percentage of weights to keep (the rest is pruned).
|
||||
`threshold` is a float between 0 and 1.
|
||||
Returns:
|
||||
mask (`torch.FloatTensor`)
|
||||
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
|
||||
retained, 0 - the associated weight is pruned).
|
||||
"""
|
||||
# Get the subnetwork by sorting the inputs and using the top threshold %
|
||||
mask = inputs.clone()
|
||||
_, idx = inputs.flatten().sort(descending=True)
|
||||
j = int(threshold * inputs.numel())
|
||||
|
||||
# flat_out and mask access the same memory.
|
||||
flat_out = mask.flatten()
|
||||
flat_out[idx[j:]] = 0
|
||||
flat_out[idx[:j]] = 1
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gradOutput):
|
||||
return gradOutput, None
|
||||
|
||||
|
||||
class MagnitudeBinarizer(object):
|
||||
"""
|
||||
Magnitude Binarizer.
|
||||
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
|
||||
is among the k% highest values of |S| (absolute value).
|
||||
|
||||
Implementation is inspired from https://github.com/NervanaSystems/distiller/blob/2291fdcc2ea642a98d4e20629acb5a9e2e04b4e6/distiller/pruning/automated_gradual_pruner.py#L24
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def apply(inputs: torch.tensor, threshold: float):
|
||||
"""
|
||||
Args:
|
||||
inputs (`torch.FloatTensor`)
|
||||
The input matrix from which the binarizer computes the binary mask.
|
||||
This input marix is typically the weight matrix.
|
||||
threshold (`float`)
|
||||
The percentage of weights to keep (the rest is pruned).
|
||||
`threshold` is a float between 0 and 1.
|
||||
Returns:
|
||||
mask (`torch.FloatTensor`)
|
||||
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
|
||||
retained, 0 - the associated weight is pruned).
|
||||
"""
|
||||
# Get the subnetwork by sorting the inputs and using the top threshold %
|
||||
mask = inputs.clone()
|
||||
_, idx = inputs.abs().flatten().sort(descending=True)
|
||||
j = int(threshold * inputs.numel())
|
||||
|
||||
# flat_out and mask access the same memory.
|
||||
flat_out = mask.flatten()
|
||||
flat_out[idx[j:]] = 0
|
||||
flat_out[idx[:j]] = 1
|
||||
return mask
|
||||
@@ -0,0 +1,107 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Masked Linear module: A fully connected layer that computes an adaptive binary mask on the fly.
|
||||
The mask (binary or not) is computed at each forward pass and multiplied against
|
||||
the weight matrix to prune a portion of the weights.
|
||||
The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added).
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import init
|
||||
|
||||
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
class MaskedLinear(nn.Linear):
|
||||
"""
|
||||
Fully Connected layer with on the fly adaptive mask.
|
||||
If needed, a score matrix is created to store the importance of each associated weight.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
mask_init: str = "constant",
|
||||
mask_scale: float = 0.0,
|
||||
pruning_method: str = "topK",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
in_features (`int`)
|
||||
Size of each input sample
|
||||
out_features (`int`)
|
||||
Size of each output sample
|
||||
bias (`bool`)
|
||||
If set to ``False``, the layer will not learn an additive bias.
|
||||
Default: ``True``
|
||||
mask_init (`str`)
|
||||
The initialization method for the score matrix if a score matrix is needed.
|
||||
Choices: ["constant", "uniform", "kaiming"]
|
||||
Default: ``constant``
|
||||
mask_scale (`float`)
|
||||
The initialization parameter for the chosen initialization method `mask_init`.
|
||||
Default: ``0.``
|
||||
pruning_method (`str`)
|
||||
Method to compute the mask.
|
||||
Choices: ["topK", "threshold", "sigmoied_threshold", "magnitude", "l0"]
|
||||
Default: ``topK``
|
||||
"""
|
||||
super(MaskedLinear, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
|
||||
assert pruning_method in ["topK", "threshold", "sigmoied_threshold", "magnitude", "l0"]
|
||||
self.pruning_method = pruning_method
|
||||
|
||||
if self.pruning_method in ["topK", "threshold", "sigmoied_threshold", "l0"]:
|
||||
self.mask_scale = mask_scale
|
||||
self.mask_init = mask_init
|
||||
self.mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
|
||||
self.init_mask()
|
||||
|
||||
def init_mask(self):
|
||||
if self.mask_init == "constant":
|
||||
init.constant_(self.mask_scores, val=self.mask_scale)
|
||||
elif self.mask_init == "uniform":
|
||||
init.uniform_(self.mask_scores, a=-self.mask_scale, b=self.mask_scale)
|
||||
elif self.mask_init == "kaiming":
|
||||
init.kaiming_uniform_(self.mask_scores, a=math.sqrt(5))
|
||||
|
||||
def forward(self, input: torch.tensor, threshold: float):
|
||||
# Get the mask
|
||||
if self.pruning_method == "topK":
|
||||
mask = TopKBinarizer.apply(self.mask_scores, threshold)
|
||||
elif self.pruning_method in ["threshold", "sigmoied_threshold"]:
|
||||
sig = "sigmoied" in self.pruning_method
|
||||
mask = ThresholdBinarizer.apply(self.mask_scores, threshold, sig)
|
||||
elif self.pruning_method == "magnitude":
|
||||
mask = MagnitudeBinarizer.apply(self.weight, threshold)
|
||||
elif self.pruning_method == "l0":
|
||||
l, r, b = -0.1, 1.1, 2 / 3
|
||||
if self.training:
|
||||
u = torch.zeros_like(self.mask_scores).uniform_().clamp(0.0001, 0.9999)
|
||||
s = torch.sigmoid((u.log() - (1 - u).log() + self.mask_scores) / b)
|
||||
else:
|
||||
s = torch.sigmoid(self.mask_scores)
|
||||
s_bar = s * (r - l) + l
|
||||
mask = s_bar.clamp(min=0.0, max=1.0)
|
||||
# Mask weights with computed mask
|
||||
weight_thresholded = mask * self.weight
|
||||
# Compute output (linear layer) with masked weights
|
||||
return F.linear(input, weight_thresholded, self.bias)
|
||||
@@ -0,0 +1,5 @@
|
||||
# LXMERT DEMO
|
||||
|
||||
1. make a virtualenv: ``virtualenv venv`` and activate ``source venv/bin/activate``
|
||||
2. install reqs: ``pip install -r ./requirements.txt``
|
||||
3. usage is as shown in demo.ipynb
|
||||
267
examples/research_projects/movement-pruning/lxmert/demo.ipynb
Normal file
267
examples/research_projects/movement-pruning/lxmert/demo.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,149 @@
|
||||
import getopt
|
||||
import json
|
||||
import os
|
||||
|
||||
# import numpy as np
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modeling_frcnn import GeneralizedRCNN
|
||||
from processing_image import Preprocess
|
||||
from utils import Config
|
||||
|
||||
|
||||
"""
|
||||
USAGE:
|
||||
``python extracting_data.py -i <img_dir> -o <dataset_file>.datasets <batch_size>``
|
||||
"""
|
||||
|
||||
|
||||
TEST = False
|
||||
CONFIG = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
|
||||
DEFAULT_SCHEMA = datasets.Features(
|
||||
OrderedDict(
|
||||
{
|
||||
"attr_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
||||
"attr_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
||||
"boxes": datasets.Array2D((CONFIG.MAX_DETECTIONS, 4), dtype="float32"),
|
||||
"img_id": datasets.Value("int32"),
|
||||
"obj_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
||||
"obj_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
|
||||
"roi_features": datasets.Array2D((CONFIG.MAX_DETECTIONS, 2048), dtype="float32"),
|
||||
"sizes": datasets.Sequence(length=2, feature=datasets.Value("float32")),
|
||||
"preds_per_image": datasets.Value(dtype="int32"),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Extract:
|
||||
def __init__(self, argv=sys.argv[1:]):
|
||||
inputdir = None
|
||||
outputfile = None
|
||||
subset_list = None
|
||||
batch_size = 1
|
||||
opts, args = getopt.getopt(argv, "i:o:b:s", ["inputdir=", "outfile=", "batch_size=", "subset_list="])
|
||||
for opt, arg in opts:
|
||||
if opt in ("-i", "--inputdir"):
|
||||
inputdir = arg
|
||||
elif opt in ("-o", "--outfile"):
|
||||
outputfile = arg
|
||||
elif opt in ("-b", "--batch_size"):
|
||||
batch_size = int(arg)
|
||||
elif opt in ("-s", "--subset_list"):
|
||||
subset_list = arg
|
||||
|
||||
assert inputdir is not None # and os.path.isdir(inputdir), f"{inputdir}"
|
||||
assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
|
||||
if subset_list is not None:
|
||||
with open(os.path.realpath(subset_list)) as f:
|
||||
self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
|
||||
else:
|
||||
self.subset_list = None
|
||||
|
||||
self.config = CONFIG
|
||||
if torch.cuda.is_available():
|
||||
self.config.model.device = "cuda"
|
||||
self.inputdir = os.path.realpath(inputdir)
|
||||
self.outputfile = os.path.realpath(outputfile)
|
||||
self.preprocess = Preprocess(self.config)
|
||||
self.model = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.config)
|
||||
self.batch = batch_size if batch_size != 0 else 1
|
||||
self.schema = DEFAULT_SCHEMA
|
||||
|
||||
def _vqa_file_split(self, file):
|
||||
img_id = int(file.split(".")[0].split("_")[-1])
|
||||
filepath = os.path.join(self.inputdir, file)
|
||||
return (img_id, filepath)
|
||||
|
||||
@property
|
||||
def file_generator(self):
|
||||
batch = []
|
||||
for i, file in enumerate(os.listdir(self.inputdir)):
|
||||
if self.subset_list is not None and i not in self.subset_list:
|
||||
continue
|
||||
batch.append(self._vqa_file_split(file))
|
||||
if len(batch) == self.batch:
|
||||
temp = batch
|
||||
batch = []
|
||||
yield list(map(list, zip(*temp)))
|
||||
|
||||
for i in range(1):
|
||||
yield list(map(list, zip(*batch)))
|
||||
|
||||
def __call__(self):
|
||||
# make writer
|
||||
if not TEST:
|
||||
writer = datasets.ArrowWriter(features=self.schema, path=self.outputfile)
|
||||
# do file generator
|
||||
for i, (img_ids, filepaths) in enumerate(self.file_generator):
|
||||
images, sizes, scales_yx = self.preprocess(filepaths)
|
||||
output_dict = self.model(
|
||||
images,
|
||||
sizes,
|
||||
scales_yx=scales_yx,
|
||||
padding="max_detections",
|
||||
max_detections=self.config.MAX_DETECTIONS,
|
||||
pad_value=0,
|
||||
return_tensors="np",
|
||||
location="cpu",
|
||||
)
|
||||
output_dict["boxes"] = output_dict.pop("normalized_boxes")
|
||||
if not TEST:
|
||||
output_dict["img_id"] = np.array(img_ids)
|
||||
batch = self.schema.encode_batch(output_dict)
|
||||
writer.write_batch(batch)
|
||||
if TEST:
|
||||
break
|
||||
# finalizer the writer
|
||||
if not TEST:
|
||||
num_examples, num_bytes = writer.finalize()
|
||||
print(f"Success! You wrote {num_examples} entry(s) and {num_bytes >> 20} mb")
|
||||
|
||||
|
||||
def tryload(stream):
|
||||
try:
|
||||
data = json.load(stream)
|
||||
try:
|
||||
data = list(data.keys())
|
||||
except Exception:
|
||||
data = [d["img_id"] for d in data]
|
||||
except Exception:
|
||||
try:
|
||||
data = eval(stream.read())
|
||||
except Exception:
|
||||
data = stream.read().split("\n")
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
extract = Extract(sys.argv[1:])
|
||||
extract()
|
||||
if not TEST:
|
||||
dataset = datasets.Dataset.from_file(extract.outputfile)
|
||||
# wala!
|
||||
# print(np.array(dataset[0:2]["roi_features"]).shape)
|
||||
1922
examples/research_projects/movement-pruning/lxmert/modeling_frcnn.py
Normal file
1922
examples/research_projects/movement-pruning/lxmert/modeling_frcnn.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
coding=utf-8
|
||||
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
|
||||
Adapted From Facebook Inc, Detectron2
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.import copy
|
||||
"""
|
||||
import sys
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from utils import img_tensorize
|
||||
|
||||
|
||||
class ResizeShortestEdge:
|
||||
def __init__(self, short_edge_length, max_size=sys.maxsize):
|
||||
"""
|
||||
Args:
|
||||
short_edge_length (list[min, max])
|
||||
max_size (int): maximum allowed longest edge length.
|
||||
"""
|
||||
self.interp_method = "bilinear"
|
||||
self.max_size = max_size
|
||||
self.short_edge_length = short_edge_length
|
||||
|
||||
def __call__(self, imgs):
|
||||
img_augs = []
|
||||
for img in imgs:
|
||||
h, w = img.shape[:2]
|
||||
# later: provide list and randomly choose index for resize
|
||||
size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
|
||||
if size == 0:
|
||||
return img
|
||||
scale = size * 1.0 / min(h, w)
|
||||
if h < w:
|
||||
newh, neww = size, scale * w
|
||||
else:
|
||||
newh, neww = scale * h, size
|
||||
if max(newh, neww) > self.max_size:
|
||||
scale = self.max_size * 1.0 / max(newh, neww)
|
||||
newh = newh * scale
|
||||
neww = neww * scale
|
||||
neww = int(neww + 0.5)
|
||||
newh = int(newh + 0.5)
|
||||
|
||||
if img.dtype == np.uint8:
|
||||
pil_image = Image.fromarray(img)
|
||||
pil_image = pil_image.resize((neww, newh), Image.BILINEAR)
|
||||
img = np.asarray(pil_image)
|
||||
else:
|
||||
img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw
|
||||
img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0)
|
||||
img_augs.append(img)
|
||||
|
||||
return img_augs
|
||||
|
||||
|
||||
class Preprocess:
|
||||
def __init__(self, cfg):
|
||||
self.aug = ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)
|
||||
self.input_format = cfg.INPUT.FORMAT
|
||||
self.size_divisibility = cfg.SIZE_DIVISIBILITY
|
||||
self.pad_value = cfg.PAD_VALUE
|
||||
self.max_image_size = cfg.INPUT.MAX_SIZE_TEST
|
||||
self.device = cfg.MODEL.DEVICE
|
||||
self.pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
|
||||
self.pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
|
||||
self.normalizer = lambda x: (x - self.pixel_mean) / self.pixel_std
|
||||
|
||||
def pad(self, images):
|
||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||
image_sizes = [im.shape[-2:] for im in images]
|
||||
images = [
|
||||
F.pad(
|
||||
im,
|
||||
[0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
|
||||
value=self.pad_value,
|
||||
)
|
||||
for size, im in zip(image_sizes, images)
|
||||
]
|
||||
|
||||
return torch.stack(images), torch.tensor(image_sizes)
|
||||
|
||||
def __call__(self, images, single_image=False):
|
||||
with torch.no_grad():
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
if single_image:
|
||||
assert len(images) == 1
|
||||
for i in range(len(images)):
|
||||
if isinstance(images[i], torch.Tensor):
|
||||
images.insert(i, images.pop(i).to(self.device).float())
|
||||
elif not isinstance(images[i], torch.Tensor):
|
||||
images.insert(
|
||||
i,
|
||||
torch.as_tensor(img_tensorize(images.pop(i), input_format=self.input_format))
|
||||
.to(self.device)
|
||||
.float(),
|
||||
)
|
||||
# resize smallest edge
|
||||
raw_sizes = torch.tensor([im.shape[:2] for im in images])
|
||||
images = self.aug(images)
|
||||
# transpose images and convert to torch tensors
|
||||
# images = [torch.as_tensor(i.astype("float32")).permute(2, 0, 1).to(self.device) for i in images]
|
||||
# now normalize before pad to avoid useless arithmetic
|
||||
images = [self.normalizer(x) for x in images]
|
||||
# now pad them to do the following operations
|
||||
images, sizes = self.pad(images)
|
||||
# Normalize
|
||||
|
||||
if self.size_divisibility > 0:
|
||||
raise NotImplementedError()
|
||||
# pad
|
||||
scales_yx = torch.true_divide(raw_sizes, sizes)
|
||||
if single_image:
|
||||
return images[0], sizes[0], scales_yx[0]
|
||||
else:
|
||||
return images, sizes, scales_yx
|
||||
|
||||
|
||||
def _scale_box(boxes, scale_yx):
|
||||
boxes[:, 0::2] *= scale_yx[:, 1]
|
||||
boxes[:, 1::2] *= scale_yx[:, 0]
|
||||
return boxes
|
||||
|
||||
|
||||
def _clip_box(tensor, box_size: Tuple[int, int]):
|
||||
assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!"
|
||||
h, w = box_size
|
||||
tensor[:, 0].clamp_(min=0, max=w)
|
||||
tensor[:, 1].clamp_(min=0, max=h)
|
||||
tensor[:, 2].clamp_(min=0, max=w)
|
||||
tensor[:, 3].clamp_(min=0, max=h)
|
||||
@@ -0,0 +1,99 @@
|
||||
appdirs==1.4.3
|
||||
argon2-cffi==20.1.0
|
||||
async-generator==1.10
|
||||
attrs==20.2.0
|
||||
backcall==0.2.0
|
||||
bleach==3.1.5
|
||||
CacheControl==0.12.6
|
||||
certifi==2020.6.20
|
||||
cffi==1.14.2
|
||||
chardet==3.0.4
|
||||
click==7.1.2
|
||||
colorama==0.4.3
|
||||
contextlib2==0.6.0
|
||||
cycler==0.10.0
|
||||
datasets==1.0.0
|
||||
decorator==4.4.2
|
||||
defusedxml==0.6.0
|
||||
dill==0.3.2
|
||||
distlib==0.3.0
|
||||
distro==1.4.0
|
||||
entrypoints==0.3
|
||||
filelock==3.0.12
|
||||
future==0.18.2
|
||||
html5lib==1.0.1
|
||||
idna==2.8
|
||||
ipaddr==2.2.0
|
||||
ipykernel==5.3.4
|
||||
ipython
|
||||
ipython-genutils==0.2.0
|
||||
ipywidgets==7.5.1
|
||||
jedi==0.17.2
|
||||
Jinja2==2.11.2
|
||||
joblib==0.16.0
|
||||
jsonschema==3.2.0
|
||||
jupyter==1.0.0
|
||||
jupyter-client==6.1.7
|
||||
jupyter-console==6.2.0
|
||||
jupyter-core==4.6.3
|
||||
jupyterlab-pygments==0.1.1
|
||||
kiwisolver==1.2.0
|
||||
lockfile==0.12.2
|
||||
MarkupSafe==1.1.1
|
||||
matplotlib==3.3.1
|
||||
mistune==0.8.4
|
||||
msgpack==0.6.2
|
||||
nbclient==0.5.0
|
||||
nbconvert==6.0.1
|
||||
nbformat==5.0.7
|
||||
nest-asyncio==1.4.0
|
||||
notebook==6.1.4
|
||||
numpy==1.19.2
|
||||
opencv-python==4.4.0.42
|
||||
packaging==20.3
|
||||
pandas==1.1.2
|
||||
pandocfilters==1.4.2
|
||||
parso==0.7.1
|
||||
pep517==0.8.2
|
||||
pexpect==4.8.0
|
||||
pickleshare==0.7.5
|
||||
Pillow==7.2.0
|
||||
progress==1.5
|
||||
prometheus-client==0.8.0
|
||||
prompt-toolkit==3.0.7
|
||||
ptyprocess==0.6.0
|
||||
pyaml==20.4.0
|
||||
pyarrow==1.0.1
|
||||
pycparser==2.20
|
||||
Pygments==2.6.1
|
||||
pyparsing==2.4.6
|
||||
pyrsistent==0.16.0
|
||||
python-dateutil==2.8.1
|
||||
pytoml==0.1.21
|
||||
pytz==2020.1
|
||||
PyYAML==5.3.1
|
||||
pyzmq==19.0.2
|
||||
qtconsole==4.7.7
|
||||
QtPy==1.9.0
|
||||
regex==2020.7.14
|
||||
requests==2.22.0
|
||||
retrying==1.3.3
|
||||
sacremoses==0.0.43
|
||||
Send2Trash==1.5.0
|
||||
sentencepiece==0.1.91
|
||||
six==1.14.0
|
||||
terminado==0.8.3
|
||||
testpath==0.4.4
|
||||
tokenizers==0.8.1rc2
|
||||
torch==1.6.0
|
||||
torchvision==0.7.0
|
||||
tornado==6.0.4
|
||||
tqdm==4.48.2
|
||||
traitlets
|
||||
transformers==3.5.1
|
||||
urllib3==1.25.8
|
||||
wcwidth==0.2.5
|
||||
webencodings==0.5.1
|
||||
wget==3.2
|
||||
widgetsnbextension==3.5.1
|
||||
xxhash==2.0.0
|
||||
559
examples/research_projects/movement-pruning/lxmert/utils.py
Normal file
559
examples/research_projects/movement-pruning/lxmert/utils.py
Normal file
@@ -0,0 +1,559 @@
|
||||
"""
|
||||
coding=utf-8
|
||||
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal, Huggingface team :)
|
||||
Adapted From Facebook Inc, Detectron2
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.import copy
|
||||
"""
|
||||
|
||||
import copy
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
import pickle as pkl
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from hashlib import sha256
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import cv2
|
||||
import requests
|
||||
import wget
|
||||
from filelock import FileLock
|
||||
from yaml import Loader, dump, load
|
||||
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
_torch_available = True
|
||||
except ImportError:
|
||||
_torch_available = False
|
||||
|
||||
|
||||
try:
|
||||
from torch.hub import _get_torch_home
|
||||
|
||||
torch_cache_home = _get_torch_home()
|
||||
except ImportError:
|
||||
torch_cache_home = os.path.expanduser(
|
||||
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
||||
)
|
||||
|
||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
|
||||
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
|
||||
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
|
||||
PATH = "/".join(str(Path(__file__).resolve()).split("/")[:-1])
|
||||
CONFIG = os.path.join(PATH, "config.yaml")
|
||||
ATTRIBUTES = os.path.join(PATH, "attributes.txt")
|
||||
OBJECTS = os.path.join(PATH, "objects.txt")
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
||||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
CONFIG_NAME = "config.yaml"
|
||||
|
||||
|
||||
def load_labels(objs=OBJECTS, attrs=ATTRIBUTES):
|
||||
vg_classes = []
|
||||
with open(objs) as f:
|
||||
for object in f.readlines():
|
||||
vg_classes.append(object.split(",")[0].lower().strip())
|
||||
|
||||
vg_attrs = []
|
||||
with open(attrs) as f:
|
||||
for object in f.readlines():
|
||||
vg_attrs.append(object.split(",")[0].lower().strip())
|
||||
return vg_classes, vg_attrs
|
||||
|
||||
|
||||
def load_checkpoint(ckp):
|
||||
r = OrderedDict()
|
||||
with open(ckp, "rb") as f:
|
||||
ckp = pkl.load(f)["model"]
|
||||
for k in copy.deepcopy(list(ckp.keys())):
|
||||
v = ckp.pop(k)
|
||||
if isinstance(v, np.ndarray):
|
||||
v = torch.tensor(v)
|
||||
else:
|
||||
assert isinstance(v, torch.tensor), type(v)
|
||||
r[k] = v
|
||||
return r
|
||||
|
||||
|
||||
class Config:
|
||||
_pointer = {}
|
||||
|
||||
def __init__(self, dictionary: dict, name: str = "root", level=0):
|
||||
self._name = name
|
||||
self._level = level
|
||||
d = {}
|
||||
for k, v in dictionary.items():
|
||||
if v is None:
|
||||
raise ValueError()
|
||||
k = copy.deepcopy(k)
|
||||
v = copy.deepcopy(v)
|
||||
if isinstance(v, dict):
|
||||
v = Config(v, name=k, level=level + 1)
|
||||
d[k] = v
|
||||
setattr(self, k, v)
|
||||
|
||||
self._pointer = d
|
||||
|
||||
def __repr__(self):
|
||||
return str(list((self._pointer.keys())))
|
||||
|
||||
def __setattr__(self, key, val):
|
||||
self.__dict__[key] = val
|
||||
self.__dict__[key.upper()] = val
|
||||
levels = key.split(".")
|
||||
last_level = len(levels) - 1
|
||||
pointer = self._pointer
|
||||
if len(levels) > 1:
|
||||
for i, l in enumerate(levels):
|
||||
if hasattr(self, l) and isinstance(getattr(self, l), Config):
|
||||
setattr(getattr(self, l), ".".join(levels[i:]), val)
|
||||
if l == last_level:
|
||||
pointer[l] = val
|
||||
else:
|
||||
pointer = pointer[l]
|
||||
|
||||
def to_dict(self):
|
||||
return self._pointer
|
||||
|
||||
def dump_yaml(self, data, file_name):
|
||||
with open(f"{file_name}", "w") as stream:
|
||||
dump(data, stream)
|
||||
|
||||
def dump_json(self, data, file_name):
|
||||
with open(f"{file_name}", "w") as stream:
|
||||
json.dump(data, stream)
|
||||
|
||||
@staticmethod
|
||||
def load_yaml(config):
|
||||
with open(config) as stream:
|
||||
data = load(stream, Loader=Loader)
|
||||
return data
|
||||
|
||||
def __str__(self):
|
||||
t = " "
|
||||
if self._name != "root":
|
||||
r = f"{t * (self._level-1)}{self._name}:\n"
|
||||
else:
|
||||
r = ""
|
||||
level = self._level
|
||||
for i, (k, v) in enumerate(self._pointer.items()):
|
||||
if isinstance(v, Config):
|
||||
r += f"{t * (self._level)}{v}\n"
|
||||
self._level += 1
|
||||
else:
|
||||
r += f"{t * (self._level)}{k}: {v} ({type(v).__name__})\n"
|
||||
self._level = level
|
||||
return r[:-1]
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
return cls(config_dict)
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_config_file = cached_path(
|
||||
config_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
# Load config dict
|
||||
if resolved_config_file is None:
|
||||
raise EnvironmentError
|
||||
|
||||
config_file = Config.load_yaml(resolved_config_file)
|
||||
|
||||
except EnvironmentError:
|
||||
msg = "Can't load config for"
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_config_file == config_file:
|
||||
print("loading configuration file from path")
|
||||
else:
|
||||
print("loading configuration file cache")
|
||||
|
||||
return Config.load_yaml(resolved_config_file), kwargs
|
||||
|
||||
|
||||
# quick compare tensors
|
||||
def compare(in_tensor):
|
||||
|
||||
out_tensor = torch.load("dump.pt", map_location=in_tensor.device)
|
||||
n1 = in_tensor.numpy()
|
||||
n2 = out_tensor.numpy()[0]
|
||||
print(n1.shape, n1[0, 0, :5])
|
||||
print(n2.shape, n2[0, 0, :5])
|
||||
assert np.allclose(
|
||||
n1, n2, rtol=0.01, atol=0.1
|
||||
), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
|
||||
raise Exception("tensors are all good")
|
||||
|
||||
# Hugging face functions below
|
||||
|
||||
|
||||
def is_remote_url(url_or_filename):
|
||||
parsed = urlparse(url_or_filename)
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
|
||||
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
|
||||
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
|
||||
legacy_format = "/" not in model_id
|
||||
if legacy_format:
|
||||
return f"{endpoint}/{model_id}-{filename}"
|
||||
else:
|
||||
return f"{endpoint}/{model_id}/{filename}"
|
||||
|
||||
|
||||
def http_get(
|
||||
url,
|
||||
temp_file,
|
||||
proxies=None,
|
||||
resume_size=0,
|
||||
user_agent=None,
|
||||
):
|
||||
ua = "python/{}".format(sys.version.split()[0])
|
||||
if _torch_available:
|
||||
ua += "; torch/{}".format(torch.__version__)
|
||||
if isinstance(user_agent, dict):
|
||||
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += "; " + user_agent
|
||||
headers = {"user-agent": ua}
|
||||
if resume_size > 0:
|
||||
headers["Range"] = "bytes=%d-" % (resume_size,)
|
||||
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
if response.status_code == 416: # Range not satisfiable
|
||||
return
|
||||
content_length = response.headers.get("Content-Length")
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
progress = tqdm(
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
total=total,
|
||||
initial=resume_size,
|
||||
desc="Downloading",
|
||||
)
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
progress.close()
|
||||
|
||||
|
||||
def get_from_cache(
|
||||
url,
|
||||
cache_dir=None,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
etag_timeout=10,
|
||||
resume_download=False,
|
||||
user_agent=None,
|
||||
local_files_only=False,
|
||||
):
|
||||
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
etag = None
|
||||
if not local_files_only:
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
||||
if response.status_code == 200:
|
||||
etag = response.headers.get("ETag")
|
||||
except (EnvironmentError, requests.exceptions.Timeout):
|
||||
# etag is already None
|
||||
pass
|
||||
|
||||
filename = url_to_filename(url, etag)
|
||||
|
||||
# get cache path to put the file
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
|
||||
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
|
||||
# try to get the last downloaded one
|
||||
if etag is None:
|
||||
if os.path.exists(cache_path):
|
||||
return cache_path
|
||||
else:
|
||||
matching_files = [
|
||||
file
|
||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
||||
if not file.endswith(".json") and not file.endswith(".lock")
|
||||
]
|
||||
if len(matching_files) > 0:
|
||||
return os.path.join(cache_dir, matching_files[-1])
|
||||
else:
|
||||
# If files cannot be found and local_files_only=True,
|
||||
# the models might've been found if local_files_only=False
|
||||
# Notify the user about that
|
||||
if local_files_only:
|
||||
raise ValueError(
|
||||
"Cannot find the requested files in the cached path and outgoing traffic has been"
|
||||
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
|
||||
" to False."
|
||||
)
|
||||
return None
|
||||
|
||||
# From now on, etag is not None.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
return cache_path
|
||||
|
||||
# Prevent parallel downloads of the same file with a lock.
|
||||
lock_path = cache_path + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
# If the download just completed while the lock was activated.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
# Even if returning early like here, the lock will be released.
|
||||
return cache_path
|
||||
|
||||
if resume_download:
|
||||
incomplete_path = cache_path + ".incomplete"
|
||||
|
||||
@contextmanager
|
||||
def _resumable_file_manager():
|
||||
with open(incomplete_path, "a+b") as f:
|
||||
yield f
|
||||
|
||||
temp_file_manager = _resumable_file_manager
|
||||
if os.path.exists(incomplete_path):
|
||||
resume_size = os.stat(incomplete_path).st_size
|
||||
else:
|
||||
resume_size = 0
|
||||
else:
|
||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
||||
resume_size = 0
|
||||
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
print(
|
||||
"%s not found in cache or force_download set to True, downloading to %s",
|
||||
url,
|
||||
temp_file.name,
|
||||
)
|
||||
|
||||
http_get(
|
||||
url,
|
||||
temp_file,
|
||||
proxies=proxies,
|
||||
resume_size=resume_size,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
os.replace(temp_file.name, cache_path)
|
||||
|
||||
meta = {"url": url, "etag": etag}
|
||||
meta_path = cache_path + ".json"
|
||||
with open(meta_path, "w") as meta_file:
|
||||
json.dump(meta, meta_file)
|
||||
|
||||
return cache_path
|
||||
|
||||
|
||||
def url_to_filename(url, etag=None):
|
||||
|
||||
url_bytes = url.encode("utf-8")
|
||||
url_hash = sha256(url_bytes)
|
||||
filename = url_hash.hexdigest()
|
||||
|
||||
if etag:
|
||||
etag_bytes = etag.encode("utf-8")
|
||||
etag_hash = sha256(etag_bytes)
|
||||
filename += "." + etag_hash.hexdigest()
|
||||
|
||||
if url.endswith(".h5"):
|
||||
filename += ".h5"
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def cached_path(
|
||||
url_or_filename,
|
||||
cache_dir=None,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
resume_download=False,
|
||||
user_agent=None,
|
||||
extract_compressed_file=False,
|
||||
force_extract=False,
|
||||
local_files_only=False,
|
||||
):
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if isinstance(url_or_filename, Path):
|
||||
url_or_filename = str(url_or_filename)
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
if is_remote_url(url_or_filename):
|
||||
# URL, so get it from the cache (downloading if necessary)
|
||||
output_path = get_from_cache(
|
||||
url_or_filename,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
user_agent=user_agent,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
elif os.path.exists(url_or_filename):
|
||||
# File, and it exists.
|
||||
output_path = url_or_filename
|
||||
elif urlparse(url_or_filename).scheme == "":
|
||||
# File, but it doesn't exist.
|
||||
raise EnvironmentError("file {} not found".format(url_or_filename))
|
||||
else:
|
||||
# Something unknown
|
||||
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
|
||||
|
||||
if extract_compressed_file:
|
||||
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
|
||||
return output_path
|
||||
|
||||
# Path where we extract compressed archives
|
||||
# We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
|
||||
output_dir, output_file = os.path.split(output_path)
|
||||
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
|
||||
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
|
||||
|
||||
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
|
||||
return output_path_extracted
|
||||
|
||||
# Prevent parallel extractions
|
||||
lock_path = output_path + ".lock"
|
||||
with FileLock(lock_path):
|
||||
shutil.rmtree(output_path_extracted, ignore_errors=True)
|
||||
os.makedirs(output_path_extracted)
|
||||
if is_zipfile(output_path):
|
||||
with ZipFile(output_path, "r") as zip_file:
|
||||
zip_file.extractall(output_path_extracted)
|
||||
zip_file.close()
|
||||
elif tarfile.is_tarfile(output_path):
|
||||
tar_file = tarfile.open(output_path)
|
||||
tar_file.extractall(output_path_extracted)
|
||||
tar_file.close()
|
||||
else:
|
||||
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
|
||||
|
||||
return output_path_extracted
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def get_data(query, delim=","):
|
||||
assert isinstance(query, str)
|
||||
if os.path.isfile(query):
|
||||
with open(query) as f:
|
||||
data = eval(f.read())
|
||||
else:
|
||||
req = requests.get(query)
|
||||
try:
|
||||
data = requests.json()
|
||||
except Exception:
|
||||
data = req.content.decode()
|
||||
assert data is not None, "could not connect"
|
||||
try:
|
||||
data = eval(data)
|
||||
except Exception:
|
||||
data = data.split("\n")
|
||||
req.close()
|
||||
return data
|
||||
|
||||
|
||||
def get_image_from_url(url):
|
||||
response = requests.get(url)
|
||||
img = np.array(Image.open(BytesIO(response.content)))
|
||||
return img
|
||||
|
||||
|
||||
# to load legacy frcnn checkpoint from detectron
|
||||
def load_frcnn_pkl_from_url(url):
|
||||
fn = url.split("/")[-1]
|
||||
if fn not in os.listdir(os.getcwd()):
|
||||
wget.download(url)
|
||||
with open(fn, "rb") as stream:
|
||||
weights = pkl.load(stream)
|
||||
model = weights.pop("model")
|
||||
new = {}
|
||||
for k, v in model.items():
|
||||
new[k] = torch.from_numpy(v)
|
||||
if "running_var" in k:
|
||||
zero = torch.Tensor([0])
|
||||
k2 = k.replace("running_var", "num_batches_tracked")
|
||||
new[k2] = zero
|
||||
return new
|
||||
|
||||
|
||||
def get_demo_path():
|
||||
print(f"{os.path.abspath(os.path.join(PATH, os.pardir))}/demo.ipynb")
|
||||
|
||||
|
||||
def img_tensorize(im, input_format="RGB"):
|
||||
assert isinstance(im, str)
|
||||
if os.path.isfile(im):
|
||||
img = cv2.imread(im)
|
||||
else:
|
||||
img = get_image_from_url(im)
|
||||
assert img is not None, f"could not connect to: {im}"
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
if input_format == "RGB":
|
||||
img = img[:, :, ::-1]
|
||||
return img
|
||||
|
||||
|
||||
def chunk(images, batch=1):
|
||||
return (images[i : i + batch] for i in range(0, len(images), batch))
|
||||
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
coding=utf-8
|
||||
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
|
||||
Adapted From Facebook Inc, Detectron2
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.import copy
|
||||
"""
|
||||
import colorsys
|
||||
import io
|
||||
|
||||
import matplotlib as mpl
|
||||
import matplotlib.colors as mplc
|
||||
import matplotlib.figure as mplfigure
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
import cv2
|
||||
from utils import img_tensorize
|
||||
|
||||
|
||||
_SMALL_OBJ = 1000
|
||||
|
||||
|
||||
class SingleImageViz:
|
||||
def __init__(
|
||||
self,
|
||||
img,
|
||||
scale=1.2,
|
||||
edgecolor="g",
|
||||
alpha=0.5,
|
||||
linestyle="-",
|
||||
saveas="test_out.jpg",
|
||||
rgb=True,
|
||||
pynb=False,
|
||||
id2obj=None,
|
||||
id2attr=None,
|
||||
pad=0.7,
|
||||
):
|
||||
"""
|
||||
img: an RGB image of shape (H, W, 3).
|
||||
"""
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.numpy().astype("np.uint8")
|
||||
if isinstance(img, str):
|
||||
img = img_tensorize(img)
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
width, height = img.shape[1], img.shape[0]
|
||||
fig = mplfigure.Figure(frameon=False)
|
||||
dpi = fig.get_dpi()
|
||||
width_in = (width * scale + 1e-2) / dpi
|
||||
height_in = (height * scale + 1e-2) / dpi
|
||||
fig.set_size_inches(width_in, height_in)
|
||||
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
|
||||
ax.axis("off")
|
||||
ax.set_xlim(0.0, width)
|
||||
ax.set_ylim(height)
|
||||
|
||||
self.saveas = saveas
|
||||
self.rgb = rgb
|
||||
self.pynb = pynb
|
||||
self.img = img
|
||||
self.edgecolor = edgecolor
|
||||
self.alpha = 0.5
|
||||
self.linestyle = linestyle
|
||||
self.font_size = int(np.sqrt(min(height, width)) * scale // 3)
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.scale = scale
|
||||
self.fig = fig
|
||||
self.ax = ax
|
||||
self.pad = pad
|
||||
self.id2obj = id2obj
|
||||
self.id2attr = id2attr
|
||||
self.canvas = FigureCanvasAgg(fig)
|
||||
|
||||
def add_box(self, box, color=None):
|
||||
if color is None:
|
||||
color = self.edgecolor
|
||||
(x0, y0, x1, y1) = box
|
||||
width = x1 - x0
|
||||
height = y1 - y0
|
||||
self.ax.add_patch(
|
||||
mpl.patches.Rectangle(
|
||||
(x0, y0),
|
||||
width,
|
||||
height,
|
||||
fill=False,
|
||||
edgecolor=color,
|
||||
linewidth=self.font_size // 3,
|
||||
alpha=self.alpha,
|
||||
linestyle=self.linestyle,
|
||||
)
|
||||
)
|
||||
|
||||
def draw_boxes(self, boxes, obj_ids=None, obj_scores=None, attr_ids=None, attr_scores=None):
|
||||
if len(boxes.shape) > 2:
|
||||
boxes = boxes[0]
|
||||
if len(obj_ids.shape) > 1:
|
||||
obj_ids = obj_ids[0]
|
||||
if len(obj_scores.shape) > 1:
|
||||
obj_scores = obj_scores[0]
|
||||
if len(attr_ids.shape) > 1:
|
||||
attr_ids = attr_ids[0]
|
||||
if len(attr_scores.shape) > 1:
|
||||
attr_scores = attr_scores[0]
|
||||
if isinstance(boxes, torch.Tensor):
|
||||
boxes = boxes.numpy()
|
||||
if isinstance(boxes, list):
|
||||
boxes = np.array(boxes)
|
||||
assert isinstance(boxes, np.ndarray)
|
||||
areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
|
||||
sorted_idxs = np.argsort(-areas).tolist()
|
||||
boxes = boxes[sorted_idxs] if boxes is not None else None
|
||||
obj_ids = obj_ids[sorted_idxs] if obj_ids is not None else None
|
||||
obj_scores = obj_scores[sorted_idxs] if obj_scores is not None else None
|
||||
attr_ids = attr_ids[sorted_idxs] if attr_ids is not None else None
|
||||
attr_scores = attr_scores[sorted_idxs] if attr_scores is not None else None
|
||||
|
||||
assigned_colors = [self._random_color(maximum=1) for _ in range(len(boxes))]
|
||||
assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
|
||||
if obj_ids is not None:
|
||||
labels = self._create_text_labels_attr(obj_ids, obj_scores, attr_ids, attr_scores)
|
||||
for i in range(len(boxes)):
|
||||
color = assigned_colors[i]
|
||||
self.add_box(boxes[i], color)
|
||||
self.draw_labels(labels[i], boxes[i], color)
|
||||
|
||||
def draw_labels(self, label, box, color):
|
||||
x0, y0, x1, y1 = box
|
||||
text_pos = (x0, y0)
|
||||
instance_area = (y1 - y0) * (x1 - x0)
|
||||
small = _SMALL_OBJ * self.scale
|
||||
if instance_area < small or y1 - y0 < 40 * self.scale:
|
||||
if y1 >= self.height - 5:
|
||||
text_pos = (x1, y0)
|
||||
else:
|
||||
text_pos = (x0, y1)
|
||||
|
||||
height_ratio = (y1 - y0) / np.sqrt(self.height * self.width)
|
||||
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
||||
font_size = np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
|
||||
font_size *= 0.75 * self.font_size
|
||||
|
||||
self.draw_text(
|
||||
text=label,
|
||||
position=text_pos,
|
||||
color=lighter_color,
|
||||
)
|
||||
|
||||
def draw_text(
|
||||
self,
|
||||
text,
|
||||
position,
|
||||
color="g",
|
||||
ha="left",
|
||||
):
|
||||
rotation = 0
|
||||
font_size = self.font_size
|
||||
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
|
||||
color[np.argmax(color)] = max(0.8, np.max(color))
|
||||
bbox = {
|
||||
"facecolor": "black",
|
||||
"alpha": self.alpha,
|
||||
"pad": self.pad,
|
||||
"edgecolor": "none",
|
||||
}
|
||||
x, y = position
|
||||
self.ax.text(
|
||||
x,
|
||||
y,
|
||||
text,
|
||||
size=font_size * self.scale,
|
||||
family="sans-serif",
|
||||
bbox=bbox,
|
||||
verticalalignment="top",
|
||||
horizontalalignment=ha,
|
||||
color=color,
|
||||
zorder=10,
|
||||
rotation=rotation,
|
||||
)
|
||||
|
||||
def save(self, saveas=None):
|
||||
if saveas is None:
|
||||
saveas = self.saveas
|
||||
if saveas.lower().endswith(".jpg") or saveas.lower().endswith(".png"):
|
||||
cv2.imwrite(
|
||||
saveas,
|
||||
self._get_buffer()[:, :, ::-1],
|
||||
)
|
||||
else:
|
||||
self.fig.savefig(saveas)
|
||||
|
||||
def _create_text_labels_attr(self, classes, scores, attr_classes, attr_scores):
|
||||
labels = [self.id2obj[i] for i in classes]
|
||||
attr_labels = [self.id2attr[i] for i in attr_classes]
|
||||
labels = [
|
||||
f"{label} {score:.2f} {attr} {attr_score:.2f}"
|
||||
for label, score, attr, attr_score in zip(labels, scores, attr_labels, attr_scores)
|
||||
]
|
||||
return labels
|
||||
|
||||
def _create_text_labels(self, classes, scores):
|
||||
labels = [self.id2obj[i] for i in classes]
|
||||
if scores is not None:
|
||||
if labels is None:
|
||||
labels = ["{:.0f}%".format(s * 100) for s in scores]
|
||||
else:
|
||||
labels = ["{} {:.0f}%".format(li, s * 100) for li, s in zip(labels, scores)]
|
||||
return labels
|
||||
|
||||
def _random_color(self, maximum=255):
|
||||
idx = np.random.randint(0, len(_COLORS))
|
||||
ret = _COLORS[idx] * maximum
|
||||
if not self.rgb:
|
||||
ret = ret[::-1]
|
||||
return ret
|
||||
|
||||
def _get_buffer(self):
|
||||
if not self.pynb:
|
||||
s, (width, height) = self.canvas.print_to_buffer()
|
||||
if (width, height) != (self.width, self.height):
|
||||
img = cv2.resize(self.img, (width, height))
|
||||
else:
|
||||
img = self.img
|
||||
else:
|
||||
buf = io.BytesIO() # works for cairo backend
|
||||
self.canvas.print_rgba(buf)
|
||||
width, height = self.width, self.height
|
||||
s = buf.getvalue()
|
||||
img = self.img
|
||||
|
||||
buffer = np.frombuffer(s, dtype="uint8")
|
||||
img_rgba = buffer.reshape(height, width, 4)
|
||||
rgb, alpha = np.split(img_rgba, [3], axis=2)
|
||||
|
||||
try:
|
||||
import numexpr as ne # fuse them with numexpr
|
||||
|
||||
visualized_image = ne.evaluate("img * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
|
||||
except ImportError:
|
||||
alpha = alpha.astype("float32") / 255.0
|
||||
visualized_image = img * (1 - alpha) + rgb * alpha
|
||||
|
||||
return visualized_image.astype("uint8")
|
||||
|
||||
def _change_color_brightness(self, color, brightness_factor):
|
||||
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
||||
color = mplc.to_rgb(color)
|
||||
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
||||
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
|
||||
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
||||
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
||||
modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
|
||||
return modified_color
|
||||
|
||||
|
||||
# Color map
|
||||
_COLORS = (
|
||||
np.array(
|
||||
[
|
||||
0.000,
|
||||
0.447,
|
||||
0.741,
|
||||
0.850,
|
||||
0.325,
|
||||
0.098,
|
||||
0.929,
|
||||
0.694,
|
||||
0.125,
|
||||
0.494,
|
||||
0.184,
|
||||
0.556,
|
||||
0.466,
|
||||
0.674,
|
||||
0.188,
|
||||
0.301,
|
||||
0.745,
|
||||
0.933,
|
||||
0.635,
|
||||
0.078,
|
||||
0.184,
|
||||
0.300,
|
||||
0.300,
|
||||
0.300,
|
||||
0.600,
|
||||
0.600,
|
||||
0.600,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.749,
|
||||
0.749,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.667,
|
||||
0.000,
|
||||
1.000,
|
||||
0.333,
|
||||
0.333,
|
||||
0.000,
|
||||
0.333,
|
||||
0.667,
|
||||
0.000,
|
||||
0.333,
|
||||
1.000,
|
||||
0.000,
|
||||
0.667,
|
||||
0.333,
|
||||
0.000,
|
||||
0.667,
|
||||
0.667,
|
||||
0.000,
|
||||
0.667,
|
||||
1.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.333,
|
||||
0.000,
|
||||
1.000,
|
||||
0.667,
|
||||
0.000,
|
||||
1.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.333,
|
||||
0.500,
|
||||
0.000,
|
||||
0.667,
|
||||
0.500,
|
||||
0.000,
|
||||
1.000,
|
||||
0.500,
|
||||
0.333,
|
||||
0.000,
|
||||
0.500,
|
||||
0.333,
|
||||
0.333,
|
||||
0.500,
|
||||
0.333,
|
||||
0.667,
|
||||
0.500,
|
||||
0.333,
|
||||
1.000,
|
||||
0.500,
|
||||
0.667,
|
||||
0.000,
|
||||
0.500,
|
||||
0.667,
|
||||
0.333,
|
||||
0.500,
|
||||
0.667,
|
||||
0.667,
|
||||
0.500,
|
||||
0.667,
|
||||
1.000,
|
||||
0.500,
|
||||
1.000,
|
||||
0.000,
|
||||
0.500,
|
||||
1.000,
|
||||
0.333,
|
||||
0.500,
|
||||
1.000,
|
||||
0.667,
|
||||
0.500,
|
||||
1.000,
|
||||
1.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.333,
|
||||
1.000,
|
||||
0.000,
|
||||
0.667,
|
||||
1.000,
|
||||
0.000,
|
||||
1.000,
|
||||
1.000,
|
||||
0.333,
|
||||
0.000,
|
||||
1.000,
|
||||
0.333,
|
||||
0.333,
|
||||
1.000,
|
||||
0.333,
|
||||
0.667,
|
||||
1.000,
|
||||
0.333,
|
||||
1.000,
|
||||
1.000,
|
||||
0.667,
|
||||
0.000,
|
||||
1.000,
|
||||
0.667,
|
||||
0.333,
|
||||
1.000,
|
||||
0.667,
|
||||
0.667,
|
||||
1.000,
|
||||
0.667,
|
||||
1.000,
|
||||
1.000,
|
||||
1.000,
|
||||
0.000,
|
||||
1.000,
|
||||
1.000,
|
||||
0.333,
|
||||
1.000,
|
||||
1.000,
|
||||
0.667,
|
||||
1.000,
|
||||
0.333,
|
||||
0.000,
|
||||
0.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.000,
|
||||
0.667,
|
||||
0.000,
|
||||
0.000,
|
||||
0.833,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.167,
|
||||
0.000,
|
||||
0.000,
|
||||
0.333,
|
||||
0.000,
|
||||
0.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.000,
|
||||
0.667,
|
||||
0.000,
|
||||
0.000,
|
||||
0.833,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.167,
|
||||
0.000,
|
||||
0.000,
|
||||
0.333,
|
||||
0.000,
|
||||
0.000,
|
||||
0.500,
|
||||
0.000,
|
||||
0.000,
|
||||
0.667,
|
||||
0.000,
|
||||
0.000,
|
||||
0.833,
|
||||
0.000,
|
||||
0.000,
|
||||
1.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.000,
|
||||
0.143,
|
||||
0.143,
|
||||
0.143,
|
||||
0.857,
|
||||
0.857,
|
||||
0.857,
|
||||
1.000,
|
||||
1.000,
|
||||
1.000,
|
||||
]
|
||||
)
|
||||
.astype(np.float32)
|
||||
.reshape(-1, 3)
|
||||
)
|
||||
953
examples/research_projects/movement-pruning/masked_run_glue.py
Normal file
953
examples/research_projects/movement-pruning/masked_run_glue.py
Normal file
@@ -0,0 +1,953 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Fine-pruning Masked BERT on sequence classification on GLUE."""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import glue_compute_metrics as compute_metrics
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
from transformers import glue_output_modes as output_modes
|
||||
from transformers import glue_processors as processors
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
"masked_bert": (MaskedBertConfig, MaskedBertForSequenceClassification, BertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def schedule_threshold(
|
||||
step: int,
|
||||
total_step: int,
|
||||
warmup_steps: int,
|
||||
initial_threshold: float,
|
||||
final_threshold: float,
|
||||
initial_warmup: int,
|
||||
final_warmup: int,
|
||||
final_lambda: float,
|
||||
):
|
||||
if step <= initial_warmup * warmup_steps:
|
||||
threshold = initial_threshold
|
||||
elif step > (total_step - final_warmup * warmup_steps):
|
||||
threshold = final_threshold
|
||||
else:
|
||||
spars_warmup_steps = initial_warmup * warmup_steps
|
||||
spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
|
||||
mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
|
||||
regu_lambda = final_lambda * threshold / final_threshold
|
||||
return threshold, regu_lambda
|
||||
|
||||
|
||||
def regularization(model: nn.Module, mode: str):
|
||||
regu, counter = 0, 0
|
||||
for name, param in model.named_parameters():
|
||||
if "mask_scores" in name:
|
||||
if mode == "l1":
|
||||
regu += torch.norm(torch.sigmoid(param), p=1) / param.numel()
|
||||
elif mode == "l0":
|
||||
regu += torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1)).sum() / param.numel()
|
||||
else:
|
||||
ValueError("Don't know this mode.")
|
||||
counter += 1
|
||||
return regu / counter
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter(log_dir=args.output_dir)
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "mask_score" in n and p.requires_grad],
|
||||
"lr": args.mask_scores_learning_rate,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "mask_score" not in n and p.requires_grad and not any(nd in n for nd in no_decay)
|
||||
],
|
||||
"lr": args.learning_rate,
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "mask_score" not in n and p.requires_grad and any(nd in n for nd in no_decay)
|
||||
],
|
||||
"lr": args.learning_rate,
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
# Distillation
|
||||
if teacher is not None:
|
||||
logger.info(" Training with distillation")
|
||||
|
||||
global_step = 0
|
||||
# Global TopK
|
||||
if args.global_topk:
|
||||
threshold_mem = None
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to global_step of last saved checkpoint from model path
|
||||
try:
|
||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||
except ValueError:
|
||||
global_step = 0
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", global_step)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained,
|
||||
int(args.num_train_epochs),
|
||||
desc="Epoch",
|
||||
disable=args.local_rank not in [-1, 0],
|
||||
)
|
||||
set_seed(args) # Added here for reproducibility
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
threshold, regu_lambda = schedule_threshold(
|
||||
step=global_step,
|
||||
total_step=t_total,
|
||||
warmup_steps=args.warmup_steps,
|
||||
final_threshold=args.final_threshold,
|
||||
initial_threshold=args.initial_threshold,
|
||||
final_warmup=args.final_warmup,
|
||||
initial_warmup=args.initial_warmup,
|
||||
final_lambda=args.final_lambda,
|
||||
)
|
||||
# Global TopK
|
||||
if args.global_topk:
|
||||
if threshold == 1.0:
|
||||
threshold = -1e2 # Or an indefinitely low quantity
|
||||
else:
|
||||
if (threshold_mem is None) or (global_step % args.global_topk_frequency_compute == 0):
|
||||
# Sort all the values to get the global topK
|
||||
concat = torch.cat(
|
||||
[param.view(-1) for name, param in model.named_parameters() if "mask_scores" in name]
|
||||
)
|
||||
n = concat.numel()
|
||||
kth = max(n - (int(n * threshold) + 1), 1)
|
||||
threshold_mem = concat.kthvalue(kth).values.item()
|
||||
threshold = threshold_mem
|
||||
else:
|
||||
threshold = threshold_mem
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "masked_bert", "xlnet", "albert"] else None
|
||||
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
|
||||
|
||||
if "masked" in args.model_type:
|
||||
inputs["threshold"] = threshold
|
||||
|
||||
outputs = model(**inputs)
|
||||
loss, logits_stu = outputs # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
# Distillation loss
|
||||
if teacher is not None:
|
||||
if "token_type_ids" not in inputs:
|
||||
inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
|
||||
with torch.no_grad():
|
||||
(logits_tea,) = teacher(
|
||||
input_ids=inputs["input_ids"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
loss_logits = (
|
||||
F.kl_div(
|
||||
input=F.log_softmax(logits_stu / args.temperature, dim=-1),
|
||||
target=F.softmax(logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
|
||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||
|
||||
# Regularization
|
||||
if args.regularization is not None:
|
||||
regu_ = regularization(model=model, mode=args.regularization)
|
||||
loss = loss + regu_lambda * regu_
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0 or (
|
||||
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
||||
len(epoch_iterator) <= args.gradient_accumulation_steps
|
||||
and (step + 1) == len(epoch_iterator)
|
||||
):
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
tb_writer.add_scalar("threshold", threshold, global_step)
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
tb_writer.add_scalar("parameter_mean/" + name, param.data.mean(), global_step)
|
||||
tb_writer.add_scalar("parameter_std/" + name, param.data.std(), global_step)
|
||||
tb_writer.add_scalar("parameter_min/" + name, param.data.min(), global_step)
|
||||
tb_writer.add_scalar("parameter_max/" + name, param.data.max(), global_step)
|
||||
tb_writer.add_scalar("grad_mean/" + name, param.grad.data.mean(), global_step)
|
||||
tb_writer.add_scalar("grad_std/" + name, param.grad.data.std(), global_step)
|
||||
if args.regularization is not None and "mask_scores" in name:
|
||||
if args.regularization == "l1":
|
||||
perc = (torch.sigmoid(param) > threshold).sum().item() / param.numel()
|
||||
elif args.regularization == "l0":
|
||||
perc = (torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1))).sum().item() / param.numel()
|
||||
tb_writer.add_scalar("retained_weights_perc/" + name, perc, global_step)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
logs = {}
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
eval_key = "eval_{}".format(key)
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||
learning_rate_scalar = scheduler.get_lr()
|
||||
logs["learning_rate"] = learning_rate_scalar[0]
|
||||
if len(learning_rate_scalar) > 1:
|
||||
for idx, lr in enumerate(learning_rate_scalar[1:]):
|
||||
logs[f"learning_rate/{idx+1}"] = lr
|
||||
logs["loss"] = loss_scalar
|
||||
if teacher is not None:
|
||||
logs["loss/distil"] = loss_logits.item()
|
||||
if args.regularization is not None:
|
||||
logs["loss/regularization"] = regu_.item()
|
||||
if (teacher is not None) or (args.regularization is not None):
|
||||
if (teacher is not None) and (args.regularization is not None):
|
||||
logs["loss/instant_ce"] = (
|
||||
loss.item()
|
||||
- regu_lambda * logs["loss/regularization"]
|
||||
- args.alpha_distil * logs["loss/distil"]
|
||||
) / args.alpha_ce
|
||||
elif teacher is not None:
|
||||
logs["loss/instant_ce"] = (
|
||||
loss.item() - args.alpha_distil * logs["loss/distil"]
|
||||
) / args.alpha_ce
|
||||
else:
|
||||
logs["loss/instant_ce"] = loss.item() - regu_lambda * logs["loss/regularization"]
|
||||
logging_loss = tr_loss
|
||||
|
||||
for key, value in logs.items():
|
||||
tb_writer.add_scalar(key, value, global_step)
|
||||
print(json.dumps({**logs, **{"step": global_step}}))
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + "/MM") if args.task_name == "mnli" else (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
|
||||
# Global TopK
|
||||
if args.global_topk:
|
||||
threshold_mem = None
|
||||
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "masked_bert", "xlnet", "albert"] else None
|
||||
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
|
||||
if "masked" in args.model_type:
|
||||
inputs["threshold"] = args.final_threshold
|
||||
if args.global_topk:
|
||||
if threshold_mem is None:
|
||||
concat = torch.cat(
|
||||
[param.view(-1) for name, param in model.named_parameters() if "mask_scores" in name]
|
||||
)
|
||||
n = concat.numel()
|
||||
kth = max(n - (int(n * args.final_threshold) + 1), 1)
|
||||
threshold_mem = concat.kthvalue(kth).values.item()
|
||||
inputs["threshold"] = threshold_mem
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
from scipy.special import softmax
|
||||
|
||||
probs = softmax(preds, axis=-1)
|
||||
entropy = np.exp((-probs * np.log(probs)).sum(axis=-1).mean())
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif args.output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
if entropy is not None:
|
||||
result["eval_avg_entropy"] = entropy
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
processor = processors[task]()
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = (
|
||||
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
max_length=args.max_seq_length,
|
||||
label_list=label_list,
|
||||
output_mode=output_mode,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
if output_mode == "classification":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training",
|
||||
action="store_true",
|
||||
help="Run evaluation during training at each logging step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case",
|
||||
action="store_true",
|
||||
help="Set this flag if you are using an uncased model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--per_gpu_train_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
|
||||
# Pruning parameters
|
||||
parser.add_argument(
|
||||
"--mask_scores_learning_rate",
|
||||
default=1e-2,
|
||||
type=float,
|
||||
help="The Adam initial learning rate of the mask scores.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_threshold", default=1.0, type=float, help="Initial value of the threshold (for scheduling)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--final_threshold", default=0.7, type=float, help="Final value of the threshold (for scheduling)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_warmup",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
|
||||
"at its `initial_threshold` value (sparsity schedule).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--final_warmup",
|
||||
default=2,
|
||||
type=int,
|
||||
help="Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
|
||||
"at its final_threshold value (sparsity schedule).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pruning_method",
|
||||
default="topK",
|
||||
type=str,
|
||||
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mask_init",
|
||||
default="constant",
|
||||
type=str,
|
||||
help="Initialization method for the mask scores. Choices: constant, uniform, kaiming.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mask_scale", default=0.0, type=float, help="Initialization parameter for the chosen initialization method."
|
||||
)
|
||||
|
||||
parser.add_argument("--regularization", default=None, help="Add L0 or L1 regularization to the mask scores.")
|
||||
parser.add_argument(
|
||||
"--final_lambda",
|
||||
default=0.0,
|
||||
type=float,
|
||||
help="Regularization intensity (used in conjunction with `regularization`.",
|
||||
)
|
||||
|
||||
parser.add_argument("--global_topk", action="store_true", help="Global TopK on the Scores.")
|
||||
parser.add_argument(
|
||||
"--global_topk_frequency_compute",
|
||||
default=25,
|
||||
type=int,
|
||||
help="Frequency at which we compute the TopK global threshold.",
|
||||
)
|
||||
|
||||
# Distillation parameters (optional)
|
||||
parser.add_argument(
|
||||
"--teacher_type",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--teacher_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the already fine-tuned teacher model. Only for distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_ce", default=0.5, type=float, help="Cross entropy loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_distil", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs",
|
||||
default=3.0,
|
||||
type=float,
|
||||
help="Total number of training epochs to perform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir",
|
||||
action="store_true",
|
||||
help="Overwrite the content of the output directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache",
|
||||
action="store_true",
|
||||
help="Overwrite the cached training and evaluation sets",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Regularization
|
||||
if args.regularization == "null":
|
||||
args.regularization = None
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name]()
|
||||
args.output_mode = output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
pruning_method=args.pruning_method,
|
||||
mask_init=args.mask_init,
|
||||
mask_scale=args.mask_scale,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
do_lower_case=args.do_lower_case,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.teacher_type is not None:
|
||||
assert args.teacher_name_or_path is not None
|
||||
assert args.alpha_distil > 0.0
|
||||
assert args.alpha_distil + args.alpha_ce > 0.0
|
||||
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
|
||||
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path)
|
||||
teacher = teacher_model_class.from_pretrained(
|
||||
args.teacher_name_or_path,
|
||||
from_tf=False,
|
||||
config=teacher_config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
teacher.to(args.device)
|
||||
else:
|
||||
teacher = None
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1133
examples/research_projects/movement-pruning/masked_run_squad.py
Normal file
1133
examples/research_projects/movement-pruning/masked_run_squad.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,6 @@
|
||||
torch>=1.4.0
|
||||
-e git+https://github.com/huggingface/transformers.git@352d5472b0c1dec0f420d606d16747d851b4bda8#egg=transformers
|
||||
knockknock>=0.1.8.1
|
||||
h5py>=2.10.0
|
||||
numpy>=1.18.2
|
||||
scipy>=1.4.1
|
||||
Reference in New Issue
Block a user