Reorganize examples (#9010)

* Reorganize example folder

* Continue reorganization

* Change requirements for tests

* Final cleanup

* Finish regroup with tests all passing

* Copyright

* Requirements and readme

* Make a full link for the documentation

* Address review comments

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Add symlink

* Reorg again

* Apply suggestions from code review

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>

* Adapt title

* Update to new strucutre

* Remove test

* Update READMEs

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2020-12-11 10:07:02 -05:00
committed by GitHub
parent 86896de064
commit 783d7d2629
215 changed files with 4454 additions and 1193 deletions

View File

@@ -0,0 +1,185 @@
# Movement Pruning: Adaptive Sparsity by Fine-Tuning
Author: @VictorSanh
*Magnitude pruning is a widely used strategy for reducing model size in pure supervised learning; however, it is less effective in the transfer learning regime that has become standard for state-of-the-art natural language processing applications. We propose the use of *movement pruning*, a simple, deterministic first-order weight pruning method that is more adaptive to pretrained model fine-tuning. Experiments show that when pruning large pretrained language models, movement pruning shows significant improvements in high-sparsity regimes. When combined with distillation, the approach achieves minimal accuracy loss with down to only 3% of the model parameters:*
| Fine-pruning+Distillation<br>(Teacher=BERT-base fine-tuned) | BERT base<br>fine-tuned | Remaining<br>Weights (%) | Magnitude Pruning | L0 Regularization | Movement Pruning | Soft Movement Pruning |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| SQuAD - Dev<br>EM/F1 | 80.4/88.1 | 10%<br>3% | 70.2/80.1<br>45.5/59.6 | 72.4/81.9<br>64.3/75.8 | 75.6/84.3<br>67.5/78.0 | **76.6/84.9**<br>**72.7/82.3** |
| MNLI - Dev<br>acc/MM acc | 84.5/84.9 | 10%<br>3% | 78.3/79.3<br>69.4/70.6 | 78.7/79.7<br>76.0/76.2 | 80.1/80.4<br>76.5/77.4 | **81.2/81.8**<br>**79.5/80.1** |
| QQP - Dev<br>acc/F1 | 91.4/88.4 | 10%<br>3% | 79.8/65.0<br>72.4/57.8 | 88.1/82.8<br>87.0/81.9 | 89.7/86.2<br>86.1/81.5 | **90.2/86.8**<br>**89.1/85.5** |
This page contains information on how to fine-prune pre-trained models such as `BERT` to obtain extremely sparse models with movement pruning. In contrast to magnitude pruning which selects weights that are far from 0, movement pruning retains weights that are moving away from 0.
For more information, we invite you to check out [our paper](https://arxiv.org/abs/2005.07683).
You can also have a look at this fun *Explain Like I'm Five* introductory [slide deck](https://www.slideshare.net/VictorSanh/movement-pruning-explain-like-im-five-234205241).
<div align="center">
<img src="https://www.seekpng.com/png/detail/166-1669328_how-to-make-emmental-cheese-at-home-icooker.png" width="400">
</div>
## Extreme sparsity and efficient storage
One promise of extreme pruning is to obtain extremely small models that can be easily sent (and stored) on edge devices. By setting weights to 0., we reduce the amount of information we need to store, and thus decreasing the memory size. We are able to obtain extremely sparse fine-pruned models with movement pruning: ~95% of the dense performance with ~5% of total remaining weights in the BERT encoder.
In [this notebook](https://github.com/huggingface/transformers/blob/master/examples/movement-pruning/Saving_PruneBERT.ipynb), we showcase how we can leverage standard tools that exist out-of-the-box to efficiently store an extremely sparse question answering model (only 6% of total remaining weights in the encoder). We are able to reduce the memory size of the encoder **from the 340MB (the original dense BERT) to 11MB**, without any additional training of the model (every operation is performed *post fine-pruning*). It is sufficiently small to store it on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical) 📎!
While movement pruning does not directly optimize for memory footprint (but rather the number of non-null weights), we hypothetize that further memory compression ratios can be achieved with specific quantization aware trainings (see for instance [Q8BERT](https://arxiv.org/abs/1910.06188), [And the Bit Goes Down](https://arxiv.org/abs/1907.05686) or [Quant-Noise](https://arxiv.org/abs/2004.07320)).
## Fine-pruned models
As examples, we release two English PruneBERT checkpoints (models fine-pruned from a pre-trained `BERT` checkpoint), one on SQuAD and the other on MNLI.
- **`prunebert-base-uncased-6-finepruned-w-distil-squad`**<br/>
Pre-trained `BERT-base-uncased` fine-pruned with soft movement pruning on SQuAD v1.1. We use an additional distillation signal from `BERT-base-uncased` finetuned on SQuAD. The encoder counts 6% of total non-null weights and reaches 83.8 F1 score. The model can be accessed with: `pruned_bert = BertForQuestionAnswering.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad")`
- **`prunebert-base-uncased-6-finepruned-w-distil-mnli`**<br/>
Pre-trained `BERT-base-uncased` fine-pruned with soft movement pruning on MNLI. We use an additional distillation signal from `BERT-base-uncased` finetuned on MNLI. The encoder counts 6% of total non-null weights and reaches 80.7 (matched) accuracy. The model can be accessed with: `pruned_bert = BertForSequenceClassification.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-mnli")`
## How to fine-prune?
### Setup
The code relies on the 🤗 Transformers library. In addition to the dependencies listed in the [`examples`](https://github.com/huggingface/transformers/tree/master/examples) folder, you should install a few additional dependencies listed in the `requirements.txt` file: `pip install -r requirements.txt`.
Note that we built our experiments on top of a stabilized version of the library (commit https://github.com/huggingface/transformers/commit/352d5472b0c1dec0f420d606d16747d851b4bda8): we do not guarantee that everything is still compatible with the latest version of the master branch.
### Fine-pruning with movement pruning
Below, we detail how to reproduce the results reported in the paper. We use SQuAD as a running example. Commands (and scripts) can be easily adapted for other tasks.
The following command fine-prunes a pre-trained `BERT-base` on SQuAD using movement pruning towards 15% of remaining weights (85% sparsity). Note that we freeze all the embeddings modules (from their pre-trained value) and only prune the Fully Connected layers in the encoder (12 layers of Transformer Block).
```bash
SERIALIZATION_DIR=<OUTPUT_DIR>
SQUAD_DATA=<SQUAD_DATA>
python examples/movement-pruning/masked_run_squad.py \
--output_dir $SERIALIZATION_DIR \
--data_dir $SQUAD_DATA \
--train_file train-v1.1.json \
--predict_file dev-v1.1.json \
--do_train --do_eval --do_lower_case \
--model_type masked_bert \
--model_name_or_path bert-base-uncased \
--per_gpu_train_batch_size 16 \
--warmup_steps 5400 \
--num_train_epochs 10 \
--learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
--initial_threshold 1 --final_threshold 0.15 \
--initial_warmup 1 --final_warmup 2 \
--pruning_method topK --mask_init constant --mask_scale 0.
```
### Fine-pruning with other methods
We can also explore other fine-pruning methods by changing the `pruning_method` parameter:
Soft movement pruning
```bash
python examples/movement-pruning/masked_run_squad.py \
--output_dir $SERIALIZATION_DIR \
--data_dir $SQUAD_DATA \
--train_file train-v1.1.json \
--predict_file dev-v1.1.json \
--do_train --do_eval --do_lower_case \
--model_type masked_bert \
--model_name_or_path bert-base-uncased \
--per_gpu_train_batch_size 16 \
--warmup_steps 5400 \
--num_train_epochs 10 \
--learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
--initial_threshold 0 --final_threshold 0.1 \
--initial_warmup 1 --final_warmup 2 \
--pruning_method sigmoied_threshold --mask_init constant --mask_scale 0. \
--regularization l1 --final_lambda 400.
```
L0 regularization
```bash
python examples/movement-pruning/masked_run_squad.py \
--output_dir $SERIALIZATION_DIR \
--data_dir $SQUAD_DATA \
--train_file train-v1.1.json \
--predict_file dev-v1.1.json \
--do_train --do_eval --do_lower_case \
--model_type masked_bert \
--model_name_or_path bert-base-uncased \
--per_gpu_train_batch_size 16 \
--warmup_steps 5400 \
--num_train_epochs 10 \
--learning_rate 3e-5 --mask_scores_learning_rate 1e-1 \
--initial_threshold 1. --final_threshold 1. \
--initial_warmup 1 --final_warmup 1 \
--pruning_method l0 --mask_init constant --mask_scale 2.197 \
--regularization l0 --final_lambda 125.
```
Iterative Magnitude Pruning
```bash
python examples/movement-pruning/masked_run_squad.py \
--output_dir ./dbg \
--data_dir examples/distillation/data/squad_data \
--train_file train-v1.1.json \
--predict_file dev-v1.1.json \
--do_train --do_eval --do_lower_case \
--model_type masked_bert \
--model_name_or_path bert-base-uncased \
--per_gpu_train_batch_size 16 \
--warmup_steps 5400 \
--num_train_epochs 10 \
--learning_rate 3e-5 \
--initial_threshold 1 --final_threshold 0.15 \
--initial_warmup 1 --final_warmup 2 \
--pruning_method magnitude
```
### After fine-pruning
**Counting parameters**
Regularization based pruning methods (soft movement pruning and L0 regularization) rely on the penalty to induce sparsity. The multiplicative coefficient controls the sparsity level.
To obtain the effective sparsity level in the encoder, we simply count the number of activated (non-null) weights:
```bash
python examples/movement-pruning/counts_parameters.py \
--pruning_method sigmoied_threshold \
--threshold 0.1 \
--serialization_dir $SERIALIZATION_DIR
```
**Pruning once for all**
Once the model has been fine-pruned, the pruned weights can be set to 0. once for all (reducing the amount of information to store). In our running experiments, we can convert a `MaskedBertForQuestionAnswering` (a BERT model augmented to enable on-the-fly pruning capabilities) to a standard `BertForQuestionAnswering`:
```bash
python examples/movement-pruning/bertarize.py \
--pruning_method sigmoied_threshold \
--threshold 0.1 \
--model_name_or_path $SERIALIZATION_DIR
```
## Hyper-parameters
For reproducibility purposes, we share the detailed results presented in the paper. These [tables](https://docs.google.com/spreadsheets/d/17JgRq_OFFTniUrz6BZWW_87DjFkKXpI1kYDSsseT_7g/edit?usp=sharing) exhaustively describe the individual hyper-parameters used for each data point.
## Inference speed
Early experiments show that even though models fine-pruned with (soft) movement pruning are extremely sparse, they do not benefit from significant improvement in terms of inference speed when using the standard PyTorch inference.
We are currently benchmarking and exploring inference setups specifically for sparse architectures.
In particular, hardware manufacturers are announcing devices that will speedup inference for sparse networks considerably.
## Citation
If you find this resource useful, please consider citing the following paper:
```
@article{sanh2020movement,
title={Movement Pruning: Adaptive Sparsity by Fine-Tuning},
author={Victor Sanh and Thomas Wolf and Alexander M. Rush},
year={2020},
eprint={2005.07683},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

View File

@@ -0,0 +1,634 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Saving PruneBERT\n",
"\n",
"\n",
"This notebook aims at showcasing how we can leverage standard tools to save (and load) an extremely sparse model fine-pruned with [movement pruning](https://arxiv.org/abs/2005.07683) (or any other unstructured pruning mehtod).\n",
"\n",
"In this example, we used BERT (base-uncased, but the procedure described here is not specific to BERT and can be applied to a large variety of models.\n",
"\n",
"We first obtain an extremely sparse model by fine-pruning with movement pruning on SQuAD v1.1. We then used the following combination of standard tools:\n",
"- We reduce the precision of the model with Int8 dynamic quantization using [PyTorch implementation](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html). We only quantized the Fully Connected Layers.\n",
"- Sparse quantized matrices are converted into the [Compressed Sparse Row format](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html).\n",
"- We use HDF5 with `gzip` compression to store the weights.\n",
"\n",
"We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical)!\n",
"\n",
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Floptical_disk_21MB.jpg/440px-Floptical_disk_21MB.jpg\" width=\"200\">\n",
"\n",
"*Note: this notebook is compatible with `torch>=1.5.0` If you are using, `torch==1.4.0`, please refer to [this previous version of the notebook](https://github.com/huggingface/transformers/commit/b11386e158e86e62d4041eabd86d044cd1695737).*"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Includes\n",
"\n",
"import h5py\n",
"import os\n",
"import json\n",
"from collections import OrderedDict\n",
"\n",
"from scipy import sparse\n",
"import numpy as np\n",
"\n",
"import torch\n",
"from torch import nn\n",
"\n",
"from transformers import *\n",
"\n",
"os.chdir('../../')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Saving"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dynamic quantization induces little or no loss of performance while significantly reducing the memory footprint."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Load fine-pruned model and quantize the model\n",
"\n",
"model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
"model.to('cpu')\n",
"\n",
"quantized_model = torch.quantization.quantize_dynamic(\n",
" model=model,\n",
" qconfig_spec = {\n",
" torch.nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
" },\n",
" dtype=torch.qint8,\n",
" )\n",
"# print(quantized_model)\n",
"\n",
"qtz_st = quantized_model.state_dict()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Saving the original (encoder + classifier) in the standard torch.save format\n",
"\n",
"dense_st = {name: param for name, param in model.state_dict().items() \n",
" if \"embedding\" not in name and \"pooler\" not in name}\n",
"torch.save(dense_st, 'dbg/dense_squad.pt',)\n",
"dense_mb_size = os.path.getsize(\"dbg/dense_squad.pt\")\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Decompose quantization for bert.encoder.layer.0.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.0.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.0.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.0.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.0.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.0.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.1.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.1.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.1.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.1.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.1.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.1.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.2.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.2.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.2.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.2.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.2.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.2.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.3.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.3.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.3.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.3.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.3.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.3.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.4.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.4.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.4.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.4.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.4.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.4.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.5.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.5.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.5.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.5.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.5.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.5.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.6.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.6.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.6.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.6.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.6.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.6.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.7.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.7.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.7.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.7.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.7.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.7.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.8.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.8.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.8.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.8.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.8.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.8.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.9.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.9.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.9.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.9.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.9.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.9.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.10.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.10.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.10.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.10.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.10.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.10.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.11.attention.self.query._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.11.attention.self.key._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.11.attention.self.value._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.11.attention.output.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.11.intermediate.dense._packed_params.weight\n",
"Decompose quantization for bert.encoder.layer.11.output.dense._packed_params.weight\n",
"Decompose quantization for bert.pooler.dense._packed_params.weight\n",
"Decompose quantization for qa_outputs._packed_params.weight\n"
]
}
],
"source": [
"# Elementary representation: we decompose the quantized tensors into (scale, zero_point, int_repr).\n",
"# See https://pytorch.org/docs/stable/quantization.html\n",
"\n",
"# We further leverage the fact that int_repr is sparse matrix to optimize the storage: we decompose int_repr into\n",
"# its CSR representation (data, indptr, indices).\n",
"\n",
"elementary_qtz_st = {}\n",
"for name, param in qtz_st.items():\n",
" if \"dtype\" not in name and param.is_quantized:\n",
" print(\"Decompose quantization for\", name)\n",
" # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
" scale = param.q_scale() # torch.tensor(1,) - float32\n",
" zero_point = param.q_zero_point() # torch.tensor(1,) - int32\n",
" elementary_qtz_st[f\"{name}.scale\"] = scale\n",
" elementary_qtz_st[f\"{name}.zero_point\"] = zero_point\n",
"\n",
" # We assume the int_repr is sparse and compute its CSR representation\n",
" # Only the FCs in the encoder are actually sparse\n",
" int_repr = param.int_repr() # torch.tensor(nb_rows, nb_columns) - int8\n",
" int_repr_cs = sparse.csr_matrix(int_repr) # scipy.sparse.csr.csr_matrix\n",
"\n",
" elementary_qtz_st[f\"{name}.int_repr.data\"] = int_repr_cs.data # np.array int8\n",
" elementary_qtz_st[f\"{name}.int_repr.indptr\"] = int_repr_cs.indptr # np.array int32\n",
" assert max(int_repr_cs.indices) < 65535 # If not, we shall fall back to int32\n",
" elementary_qtz_st[f\"{name}.int_repr.indices\"] = np.uint16(int_repr_cs.indices) # np.array uint16\n",
" elementary_qtz_st[f\"{name}.int_repr.shape\"] = int_repr_cs.shape # tuple(int, int)\n",
" else:\n",
" elementary_qtz_st[name] = param\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Create mapping from torch.dtype to string description (we could also used an int8 instead of string)\n",
"str_2_dtype = {\"qint8\": torch.qint8}\n",
"dtype_2_str = {torch.qint8: \"qint8\"}\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Encoder Size (MB) - Sparse & Quantized - `torch.save`: 21.29\n"
]
}
],
"source": [
"# Saving the pruned (encoder + classifier) in the standard torch.save format\n",
"\n",
"dense_optimized_st = {name: param for name, param in elementary_qtz_st.items() \n",
" if \"embedding\" not in name and \"pooler\" not in name}\n",
"torch.save(dense_optimized_st, 'dbg/dense_squad_optimized.pt',)\n",
"print(\"Encoder Size (MB) - Sparse & Quantized - `torch.save`:\",\n",
" round(os.path.getsize(\"dbg/dense_squad_optimized.pt\")/1e6, 2))\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Skip bert.embeddings.word_embeddings.weight\n",
"Skip bert.embeddings.position_embeddings.weight\n",
"Skip bert.embeddings.token_type_embeddings.weight\n",
"Skip bert.embeddings.LayerNorm.weight\n",
"Skip bert.embeddings.LayerNorm.bias\n",
"Skip bert.pooler.dense.scale\n",
"Skip bert.pooler.dense.zero_point\n",
"Skip bert.pooler.dense._packed_params.weight.scale\n",
"Skip bert.pooler.dense._packed_params.weight.zero_point\n",
"Skip bert.pooler.dense._packed_params.weight.int_repr.data\n",
"Skip bert.pooler.dense._packed_params.weight.int_repr.indptr\n",
"Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n",
"Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n",
"Skip bert.pooler.dense._packed_params.bias\n",
"Skip bert.pooler.dense._packed_params.dtype\n",
"\n",
"Encoder Size (MB) - Dense: 340.26\n",
"Encoder Size (MB) - Sparse & Quantized: 11.28\n"
]
}
],
"source": [
"# Save the decomposed state_dict with an HDF5 file\n",
"# Saving only the encoder + QA Head\n",
"\n",
"with h5py.File('dbg/squad_sparse.h5','w') as hf:\n",
" for name, param in elementary_qtz_st.items():\n",
" if \"embedding\" in name:\n",
" print(f\"Skip {name}\")\n",
" continue\n",
"\n",
" if \"pooler\" in name:\n",
" print(f\"Skip {name}\")\n",
" continue\n",
"\n",
" if type(param) == torch.Tensor:\n",
" if param.numel() == 1:\n",
" # module scale\n",
" # module zero_point\n",
" hf.attrs[name] = param\n",
" continue\n",
"\n",
" if param.requires_grad:\n",
" # LayerNorm\n",
" param = param.detach().numpy()\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
" # float - tensor _packed_params.weight.scale\n",
" # int - tensor _packed_params.weight.zero_point\n",
" # tuple - tensor _packed_params.weight.shape\n",
" hf.attrs[name] = param\n",
"\n",
" elif type(param) == torch.dtype:\n",
" # dtype - tensor _packed_params.dtype\n",
" hf.attrs[name] = dtype_2_str[param]\n",
" \n",
" else:\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
"\n",
"with open('dbg/metadata.json', 'w') as f:\n",
" f.write(json.dumps(qtz_st._metadata)) \n",
"\n",
"size = os.path.getsize(\"dbg/squad_sparse.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
"print(\"\")\n",
"print(\"Encoder Size (MB) - Dense: \", round(dense_mb_size/1e6, 2))\n",
"print(\"Encoder Size (MB) - Sparse & Quantized:\", round(size/1e6, 2))\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Size (MB): 99.41\n"
]
}
],
"source": [
"# Save the decomposed state_dict to HDF5 storage\n",
"# Save everything in the architecutre (embedding + encoder + QA Head)\n",
"\n",
"with h5py.File('dbg/squad_sparse_with_embs.h5','w') as hf:\n",
" for name, param in elementary_qtz_st.items():\n",
"# if \"embedding\" in name:\n",
"# print(f\"Skip {name}\")\n",
"# continue\n",
"\n",
"# if \"pooler\" in name:\n",
"# print(f\"Skip {name}\")\n",
"# continue\n",
"\n",
" if type(param) == torch.Tensor:\n",
" if param.numel() == 1:\n",
" # module scale\n",
" # module zero_point\n",
" hf.attrs[name] = param\n",
" continue\n",
"\n",
" if param.requires_grad:\n",
" # LayerNorm\n",
" param = param.detach().numpy()\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
" # float - tensor _packed_params.weight.scale\n",
" # int - tensor _packed_params.weight.zero_point\n",
" # tuple - tensor _packed_params.weight.shape\n",
" hf.attrs[name] = param\n",
"\n",
" elif type(param) == torch.dtype:\n",
" # dtype - tensor _packed_params.dtype\n",
" hf.attrs[name] = dtype_2_str[param]\n",
" \n",
" else:\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
"\n",
"\n",
"with open('dbg/metadata.json', 'w') as f:\n",
" f.write(json.dumps(qtz_st._metadata)) \n",
"\n",
"size = os.path.getsize(\"dbg/squad_sparse_with_embs.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
"print('\\nSize (MB):', round(size/1e6, 2))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Reconstruct the elementary state dict\n",
"\n",
"reconstructed_elementary_qtz_st = {}\n",
"\n",
"hf = h5py.File('dbg/squad_sparse_with_embs.h5','r')\n",
"\n",
"for attr_name, attr_param in hf.attrs.items():\n",
" if 'shape' in attr_name:\n",
" attr_param = tuple(attr_param)\n",
" elif \".scale\" in attr_name:\n",
" if \"_packed_params\" in attr_name:\n",
" attr_param = float(attr_param)\n",
" else:\n",
" attr_param = torch.tensor(attr_param)\n",
" elif \".zero_point\" in attr_name:\n",
" if \"_packed_params\" in attr_name:\n",
" attr_param = int(attr_param)\n",
" else:\n",
" attr_param = torch.tensor(attr_param)\n",
" elif \".dtype\" in attr_name:\n",
" attr_param = str_2_dtype[attr_param]\n",
" reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
" # print(f\"Unpack {attr_name}\")\n",
" \n",
"# Get the tensors/arrays\n",
"for data_name, data_param in hf.items():\n",
" if \"LayerNorm\" in data_name or \"_packed_params.bias\" in data_name:\n",
" reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
" elif \"embedding\" in data_name:\n",
" reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
" else: # _packed_params.weight.int_repr.data, _packed_params.weight.int_repr.indices and _packed_params.weight.int_repr.indptr\n",
" data_param = np.array(data_param)\n",
" if \"indices\" in data_name:\n",
" data_param = np.array(data_param, dtype=np.int32)\n",
" reconstructed_elementary_qtz_st[data_name] = data_param\n",
" # print(f\"Unpack {data_name}\")\n",
" \n",
"\n",
"hf.close()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Sanity checks\n",
"\n",
"for name, param in reconstructed_elementary_qtz_st.items():\n",
" assert name in elementary_qtz_st\n",
"for name, param in elementary_qtz_st.items():\n",
" assert name in reconstructed_elementary_qtz_st, name\n",
"\n",
"for name, param in reconstructed_elementary_qtz_st.items():\n",
" assert type(param) == type(elementary_qtz_st[name]), name\n",
" if type(param) == torch.Tensor:\n",
" assert torch.all(torch.eq(param, elementary_qtz_st[name])), name\n",
" elif type(param) == np.ndarray:\n",
" assert (param == elementary_qtz_st[name]).all(), name\n",
" else:\n",
" assert param == elementary_qtz_st[name], name"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Re-assemble the sparse int_repr from the CSR format\n",
"\n",
"reconstructed_qtz_st = {}\n",
"\n",
"for name, param in reconstructed_elementary_qtz_st.items():\n",
" if \"weight.int_repr.indptr\" in name:\n",
" prefix_ = name[:-16]\n",
" data = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.data\"]\n",
" indptr = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indptr\"]\n",
" indices = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indices\"]\n",
" shape = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.shape\"]\n",
"\n",
" int_repr = sparse.csr_matrix(arg1=(data, indices, indptr),\n",
" shape=shape)\n",
" int_repr = torch.tensor(int_repr.todense())\n",
"\n",
" scale = reconstructed_elementary_qtz_st[f\"{prefix_}.scale\"]\n",
" zero_point = reconstructed_elementary_qtz_st[f\"{prefix_}.zero_point\"]\n",
" weight = torch._make_per_tensor_quantized_tensor(int_repr,\n",
" scale,\n",
" zero_point)\n",
"\n",
" reconstructed_qtz_st[f\"{prefix_}\"] = weight\n",
" elif \"int_repr.data\" in name or \"int_repr.shape\" in name or \"int_repr.indices\" in name or \\\n",
" \"weight.scale\" in name or \"weight.zero_point\" in name:\n",
" continue\n",
" else:\n",
" reconstructed_qtz_st[name] = param\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Sanity checks\n",
"\n",
"for name, param in reconstructed_qtz_st.items():\n",
" assert name in qtz_st\n",
"for name, param in qtz_st.items():\n",
" assert name in reconstructed_qtz_st, name\n",
"\n",
"for name, param in reconstructed_qtz_st.items():\n",
" assert type(param) == type(qtz_st[name]), name\n",
" if type(param) == torch.Tensor:\n",
" assert torch.all(torch.eq(param, qtz_st[name])), name\n",
" elif type(param) == np.ndarray:\n",
" assert (param == qtz_st[name]).all(), name\n",
" else:\n",
" assert param == qtz_st[name], name"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sanity checks"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the re-constructed state dict into a model\n",
"\n",
"dummy_model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')\n",
"dummy_model.to('cpu')\n",
"\n",
"reconstructed_qtz_model = torch.quantization.quantize_dynamic(\n",
" model=dummy_model,\n",
" qconfig_spec = None,\n",
" dtype=torch.qint8,\n",
" )\n",
"\n",
"reconstructed_qtz_st = OrderedDict(reconstructed_qtz_st)\n",
"with open('dbg/metadata.json', 'r') as read_file:\n",
" metadata = json.loads(read_file.read())\n",
"reconstructed_qtz_st._metadata = metadata\n",
"\n",
"reconstructed_qtz_model.load_state_dict(reconstructed_qtz_st)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sanity check passed\n"
]
}
],
"source": [
"# Sanity checks on the infernce\n",
"\n",
"N = 32\n",
"\n",
"for _ in range(25):\n",
" inputs = torch.randint(low=0, high=30000, size=(N, 128))\n",
" mask = torch.ones(size=(N, 128))\n",
"\n",
" y_reconstructed = reconstructed_qtz_model(input_ids=inputs, attention_mask=mask)[0]\n",
" y = quantized_model(input_ids=inputs, attention_mask=mask)[0]\n",
" \n",
" assert torch.all(torch.eq(y, y_reconstructed))\n",
"print(\"Sanity check passed\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,132 @@
# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once for all.
For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceClassification` is trained, it can be saved (and then loaded)
as a standard :class:`~transformers.BertForSequenceClassification`.
"""
import argparse
import os
import shutil
import torch
from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
def main(args):
pruning_method = args.pruning_method
threshold = args.threshold
model_name_or_path = args.model_name_or_path.rstrip("/")
target_model_path = args.target_model_path
print(f"Load fine-pruned model from {model_name_or_path}")
model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
pruned_model = {}
for name, tensor in model.items():
if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
pruned_model[name] = tensor
print(f"Copied layer {name}")
elif "classifier" in name or "qa_output" in name:
pruned_model[name] = tensor
print(f"Copied layer {name}")
elif "bias" in name:
pruned_model[name] = tensor
print(f"Copied layer {name}")
else:
if pruning_method == "magnitude":
mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
pruned_model[name] = tensor * mask
print(f"Pruned layer {name}")
elif pruning_method == "topK":
if "mask_scores" in name:
continue
prefix_ = name[:-6]
scores = model[f"{prefix_}mask_scores"]
mask = TopKBinarizer.apply(scores, threshold)
pruned_model[name] = tensor * mask
print(f"Pruned layer {name}")
elif pruning_method == "sigmoied_threshold":
if "mask_scores" in name:
continue
prefix_ = name[:-6]
scores = model[f"{prefix_}mask_scores"]
mask = ThresholdBinarizer.apply(scores, threshold, True)
pruned_model[name] = tensor * mask
print(f"Pruned layer {name}")
elif pruning_method == "l0":
if "mask_scores" in name:
continue
prefix_ = name[:-6]
scores = model[f"{prefix_}mask_scores"]
l, r = -0.1, 1.1
s = torch.sigmoid(scores)
s_bar = s * (r - l) + l
mask = s_bar.clamp(min=0.0, max=1.0)
pruned_model[name] = tensor * mask
print(f"Pruned layer {name}")
else:
raise ValueError("Unknown pruning method")
if target_model_path is None:
target_model_path = os.path.join(
os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}"
)
if not os.path.isdir(target_model_path):
shutil.copytree(model_name_or_path, target_model_path)
print(f"\nCreated folder {target_model_path}")
torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
print("\nPruned model saved! See you later!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pruning_method",
choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
type=str,
required=True,
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
)
parser.add_argument(
"--threshold",
type=float,
required=False,
help="For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
"Not needed for `l0`",
)
parser.add_argument(
"--model_name_or_path",
type=str,
required=True,
help="Folder containing the model that was previously fine-pruned",
)
parser.add_argument(
"--target_model_path",
default=None,
type=str,
required=False,
help="Folder containing the model that was previously fine-pruned",
)
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,92 @@
# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Count remaining (non-zero) weights in the encoder (i.e. the transformer layers).
Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %.
"""
import argparse
import os
import torch
from emmental.modules import ThresholdBinarizer, TopKBinarizer
def main(args):
serialization_dir = args.serialization_dir
pruning_method = args.pruning_method
threshold = args.threshold
st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cpu")
remaining_count = 0 # Number of remaining (not pruned) params in the encoder
encoder_count = 0 # Number of params in the encoder
print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
for name, param in st.items():
if "encoder" not in name:
continue
if "mask_scores" in name:
if pruning_method == "topK":
mask_ones = TopKBinarizer.apply(param, threshold).sum().item()
elif pruning_method == "sigmoied_threshold":
mask_ones = ThresholdBinarizer.apply(param, threshold, True).sum().item()
elif pruning_method == "l0":
l, r = -0.1, 1.1
s = torch.sigmoid(param)
s_bar = s * (r - l) + l
mask = s_bar.clamp(min=0.0, max=1.0)
mask_ones = (mask > 0.0).sum().item()
else:
raise ValueError("Unknown pruning method")
remaining_count += mask_ones
print(name.ljust(60, " "), str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "), str(mask_ones))
else:
encoder_count += param.numel()
if "bias" in name or "LayerNorm" in name:
remaining_count += param.numel()
print("")
print("Remaining Weights (global) %: ", 100 * remaining_count / encoder_count)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pruning_method",
choices=["l0", "topK", "sigmoied_threshold"],
type=str,
required=True,
help="Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
)
parser.add_argument(
"--threshold",
type=float,
required=False,
help="For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
"Not needed for `l0`",
)
parser.add_argument(
"--serialization_dir",
type=str,
required=True,
help="Folder containing the model that was previously fine-pruned",
)
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,10 @@
# flake8: noqa
from .configuration_bert_masked import MaskedBertConfig
from .modeling_bert_masked import (
MaskedBertForMultipleChoice,
MaskedBertForQuestionAnswering,
MaskedBertForSequenceClassification,
MaskedBertForTokenClassification,
MaskedBertModel,
)
from .modules import *

View File

@@ -0,0 +1,71 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Masked BERT model configuration. It replicates the class `~transformers.BertConfig`
and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init` and `mask_scale`."""
import logging
from transformers.configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
class MaskedBertConfig(PretrainedConfig):
"""
A class replicating the `~transformers.BertConfig` with additional parameters for pruning/masking configuration.
"""
model_type = "masked_bert"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
pruning_method="topK",
mask_init="constant",
mask_scale=0.0,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.pruning_method = pruning_method
self.mask_init = mask_init
self.mask_scale = mask_scale

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
# flake8: noqa
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
from .masked_nn import MaskedLinear

View File

@@ -0,0 +1,144 @@
# coding=utf-8
# Copyright 2020-present, AllenAI Authors, University of Illinois Urbana-Champaign,
# Intel Nervana Systems and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Binarizers take a (real value) matrix as input and produce a binary (values in {0,1}) mask of the same shape.
"""
import torch
from torch import autograd
class ThresholdBinarizer(autograd.Function):
"""
Thresholdd binarizer.
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j} > \tau`
where `\tau` is a real value threshold.
Implementation is inspired from:
https://github.com/arunmallya/piggyback
Piggyback: Adapting a Single Network to Multiple Tasks by Learning to Mask Weights
Arun Mallya, Dillon Davis, Svetlana Lazebnik
"""
@staticmethod
def forward(ctx, inputs: torch.tensor, threshold: float, sigmoid: bool):
"""
Args:
inputs (`torch.FloatTensor`)
The input matrix from which the binarizer computes the binary mask.
threshold (`float`)
The threshold value (in R).
sigmoid (`bool`)
If set to ``True``, we apply the sigmoid function to the `inputs` matrix before comparing to `threshold`.
In this case, `threshold` should be a value between 0 and 1.
Returns:
mask (`torch.FloatTensor`)
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
retained, 0 - the associated weight is pruned).
"""
nb_elems = inputs.numel()
nb_min = int(0.005 * nb_elems) + 1
if sigmoid:
mask = (torch.sigmoid(inputs) > threshold).type(inputs.type())
else:
mask = (inputs > threshold).type(inputs.type())
if mask.sum() < nb_min:
# We limit the pruning so that at least 0.5% (half a percent) of the weights are remaining
k_threshold = inputs.flatten().kthvalue(max(nb_elems - nb_min, 1)).values
mask = (inputs > k_threshold).type(inputs.type())
return mask
@staticmethod
def backward(ctx, gradOutput):
return gradOutput, None, None
class TopKBinarizer(autograd.Function):
"""
Top-k Binarizer.
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
is among the k% highest values of S.
Implementation is inspired from:
https://github.com/allenai/hidden-networks
What's hidden in a randomly weighted neural network?
Vivek Ramanujan*, Mitchell Wortsman*, Aniruddha Kembhavi, Ali Farhadi, Mohammad Rastegari
"""
@staticmethod
def forward(ctx, inputs: torch.tensor, threshold: float):
"""
Args:
inputs (`torch.FloatTensor`)
The input matrix from which the binarizer computes the binary mask.
threshold (`float`)
The percentage of weights to keep (the rest is pruned).
`threshold` is a float between 0 and 1.
Returns:
mask (`torch.FloatTensor`)
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
retained, 0 - the associated weight is pruned).
"""
# Get the subnetwork by sorting the inputs and using the top threshold %
mask = inputs.clone()
_, idx = inputs.flatten().sort(descending=True)
j = int(threshold * inputs.numel())
# flat_out and mask access the same memory.
flat_out = mask.flatten()
flat_out[idx[j:]] = 0
flat_out[idx[:j]] = 1
return mask
@staticmethod
def backward(ctx, gradOutput):
return gradOutput, None
class MagnitudeBinarizer(object):
"""
Magnitude Binarizer.
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
is among the k% highest values of |S| (absolute value).
Implementation is inspired from https://github.com/NervanaSystems/distiller/blob/2291fdcc2ea642a98d4e20629acb5a9e2e04b4e6/distiller/pruning/automated_gradual_pruner.py#L24
"""
@staticmethod
def apply(inputs: torch.tensor, threshold: float):
"""
Args:
inputs (`torch.FloatTensor`)
The input matrix from which the binarizer computes the binary mask.
This input marix is typically the weight matrix.
threshold (`float`)
The percentage of weights to keep (the rest is pruned).
`threshold` is a float between 0 and 1.
Returns:
mask (`torch.FloatTensor`)
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
retained, 0 - the associated weight is pruned).
"""
# Get the subnetwork by sorting the inputs and using the top threshold %
mask = inputs.clone()
_, idx = inputs.abs().flatten().sort(descending=True)
j = int(threshold * inputs.numel())
# flat_out and mask access the same memory.
flat_out = mask.flatten()
flat_out[idx[j:]] = 0
flat_out[idx[:j]] = 1
return mask

View File

@@ -0,0 +1,107 @@
# coding=utf-8
# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Masked Linear module: A fully connected layer that computes an adaptive binary mask on the fly.
The mask (binary or not) is computed at each forward pass and multiplied against
the weight matrix to prune a portion of the weights.
The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added).
"""
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
class MaskedLinear(nn.Linear):
"""
Fully Connected layer with on the fly adaptive mask.
If needed, a score matrix is created to store the importance of each associated weight.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
mask_init: str = "constant",
mask_scale: float = 0.0,
pruning_method: str = "topK",
):
"""
Args:
in_features (`int`)
Size of each input sample
out_features (`int`)
Size of each output sample
bias (`bool`)
If set to ``False``, the layer will not learn an additive bias.
Default: ``True``
mask_init (`str`)
The initialization method for the score matrix if a score matrix is needed.
Choices: ["constant", "uniform", "kaiming"]
Default: ``constant``
mask_scale (`float`)
The initialization parameter for the chosen initialization method `mask_init`.
Default: ``0.``
pruning_method (`str`)
Method to compute the mask.
Choices: ["topK", "threshold", "sigmoied_threshold", "magnitude", "l0"]
Default: ``topK``
"""
super(MaskedLinear, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
assert pruning_method in ["topK", "threshold", "sigmoied_threshold", "magnitude", "l0"]
self.pruning_method = pruning_method
if self.pruning_method in ["topK", "threshold", "sigmoied_threshold", "l0"]:
self.mask_scale = mask_scale
self.mask_init = mask_init
self.mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.init_mask()
def init_mask(self):
if self.mask_init == "constant":
init.constant_(self.mask_scores, val=self.mask_scale)
elif self.mask_init == "uniform":
init.uniform_(self.mask_scores, a=-self.mask_scale, b=self.mask_scale)
elif self.mask_init == "kaiming":
init.kaiming_uniform_(self.mask_scores, a=math.sqrt(5))
def forward(self, input: torch.tensor, threshold: float):
# Get the mask
if self.pruning_method == "topK":
mask = TopKBinarizer.apply(self.mask_scores, threshold)
elif self.pruning_method in ["threshold", "sigmoied_threshold"]:
sig = "sigmoied" in self.pruning_method
mask = ThresholdBinarizer.apply(self.mask_scores, threshold, sig)
elif self.pruning_method == "magnitude":
mask = MagnitudeBinarizer.apply(self.weight, threshold)
elif self.pruning_method == "l0":
l, r, b = -0.1, 1.1, 2 / 3
if self.training:
u = torch.zeros_like(self.mask_scores).uniform_().clamp(0.0001, 0.9999)
s = torch.sigmoid((u.log() - (1 - u).log() + self.mask_scores) / b)
else:
s = torch.sigmoid(self.mask_scores)
s_bar = s * (r - l) + l
mask = s_bar.clamp(min=0.0, max=1.0)
# Mask weights with computed mask
weight_thresholded = mask * self.weight
# Compute output (linear layer) with masked weights
return F.linear(input, weight_thresholded, self.bias)

View File

@@ -0,0 +1,5 @@
# LXMERT DEMO
1. make a virtualenv: ``virtualenv venv`` and activate ``source venv/bin/activate``
2. install reqs: ``pip install -r ./requirements.txt``
3. usage is as shown in demo.ipynb

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,149 @@
import getopt
import json
import os
# import numpy as np
import sys
from collections import OrderedDict
import datasets
import numpy as np
import torch
from modeling_frcnn import GeneralizedRCNN
from processing_image import Preprocess
from utils import Config
"""
USAGE:
``python extracting_data.py -i <img_dir> -o <dataset_file>.datasets <batch_size>``
"""
TEST = False
CONFIG = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
DEFAULT_SCHEMA = datasets.Features(
OrderedDict(
{
"attr_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
"attr_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
"boxes": datasets.Array2D((CONFIG.MAX_DETECTIONS, 4), dtype="float32"),
"img_id": datasets.Value("int32"),
"obj_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
"obj_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
"roi_features": datasets.Array2D((CONFIG.MAX_DETECTIONS, 2048), dtype="float32"),
"sizes": datasets.Sequence(length=2, feature=datasets.Value("float32")),
"preds_per_image": datasets.Value(dtype="int32"),
}
)
)
class Extract:
def __init__(self, argv=sys.argv[1:]):
inputdir = None
outputfile = None
subset_list = None
batch_size = 1
opts, args = getopt.getopt(argv, "i:o:b:s", ["inputdir=", "outfile=", "batch_size=", "subset_list="])
for opt, arg in opts:
if opt in ("-i", "--inputdir"):
inputdir = arg
elif opt in ("-o", "--outfile"):
outputfile = arg
elif opt in ("-b", "--batch_size"):
batch_size = int(arg)
elif opt in ("-s", "--subset_list"):
subset_list = arg
assert inputdir is not None # and os.path.isdir(inputdir), f"{inputdir}"
assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
if subset_list is not None:
with open(os.path.realpath(subset_list)) as f:
self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
else:
self.subset_list = None
self.config = CONFIG
if torch.cuda.is_available():
self.config.model.device = "cuda"
self.inputdir = os.path.realpath(inputdir)
self.outputfile = os.path.realpath(outputfile)
self.preprocess = Preprocess(self.config)
self.model = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.config)
self.batch = batch_size if batch_size != 0 else 1
self.schema = DEFAULT_SCHEMA
def _vqa_file_split(self, file):
img_id = int(file.split(".")[0].split("_")[-1])
filepath = os.path.join(self.inputdir, file)
return (img_id, filepath)
@property
def file_generator(self):
batch = []
for i, file in enumerate(os.listdir(self.inputdir)):
if self.subset_list is not None and i not in self.subset_list:
continue
batch.append(self._vqa_file_split(file))
if len(batch) == self.batch:
temp = batch
batch = []
yield list(map(list, zip(*temp)))
for i in range(1):
yield list(map(list, zip(*batch)))
def __call__(self):
# make writer
if not TEST:
writer = datasets.ArrowWriter(features=self.schema, path=self.outputfile)
# do file generator
for i, (img_ids, filepaths) in enumerate(self.file_generator):
images, sizes, scales_yx = self.preprocess(filepaths)
output_dict = self.model(
images,
sizes,
scales_yx=scales_yx,
padding="max_detections",
max_detections=self.config.MAX_DETECTIONS,
pad_value=0,
return_tensors="np",
location="cpu",
)
output_dict["boxes"] = output_dict.pop("normalized_boxes")
if not TEST:
output_dict["img_id"] = np.array(img_ids)
batch = self.schema.encode_batch(output_dict)
writer.write_batch(batch)
if TEST:
break
# finalizer the writer
if not TEST:
num_examples, num_bytes = writer.finalize()
print(f"Success! You wrote {num_examples} entry(s) and {num_bytes >> 20} mb")
def tryload(stream):
try:
data = json.load(stream)
try:
data = list(data.keys())
except Exception:
data = [d["img_id"] for d in data]
except Exception:
try:
data = eval(stream.read())
except Exception:
data = stream.read().split("\n")
return data
if __name__ == "__main__":
extract = Extract(sys.argv[1:])
extract()
if not TEST:
dataset = datasets.Dataset.from_file(extract.outputfile)
# wala!
# print(np.array(dataset[0:2]["roi_features"]).shape)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,147 @@
"""
coding=utf-8
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
Adapted From Facebook Inc, Detectron2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.import copy
"""
import sys
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from utils import img_tensorize
class ResizeShortestEdge:
def __init__(self, short_edge_length, max_size=sys.maxsize):
"""
Args:
short_edge_length (list[min, max])
max_size (int): maximum allowed longest edge length.
"""
self.interp_method = "bilinear"
self.max_size = max_size
self.short_edge_length = short_edge_length
def __call__(self, imgs):
img_augs = []
for img in imgs:
h, w = img.shape[:2]
# later: provide list and randomly choose index for resize
size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
if size == 0:
return img
scale = size * 1.0 / min(h, w)
if h < w:
newh, neww = size, scale * w
else:
newh, neww = scale * h, size
if max(newh, neww) > self.max_size:
scale = self.max_size * 1.0 / max(newh, neww)
newh = newh * scale
neww = neww * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
if img.dtype == np.uint8:
pil_image = Image.fromarray(img)
pil_image = pil_image.resize((neww, newh), Image.BILINEAR)
img = np.asarray(pil_image)
else:
img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw
img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0)
img_augs.append(img)
return img_augs
class Preprocess:
def __init__(self, cfg):
self.aug = ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)
self.input_format = cfg.INPUT.FORMAT
self.size_divisibility = cfg.SIZE_DIVISIBILITY
self.pad_value = cfg.PAD_VALUE
self.max_image_size = cfg.INPUT.MAX_SIZE_TEST
self.device = cfg.MODEL.DEVICE
self.pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
self.pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
self.normalizer = lambda x: (x - self.pixel_mean) / self.pixel_std
def pad(self, images):
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
image_sizes = [im.shape[-2:] for im in images]
images = [
F.pad(
im,
[0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
value=self.pad_value,
)
for size, im in zip(image_sizes, images)
]
return torch.stack(images), torch.tensor(image_sizes)
def __call__(self, images, single_image=False):
with torch.no_grad():
if not isinstance(images, list):
images = [images]
if single_image:
assert len(images) == 1
for i in range(len(images)):
if isinstance(images[i], torch.Tensor):
images.insert(i, images.pop(i).to(self.device).float())
elif not isinstance(images[i], torch.Tensor):
images.insert(
i,
torch.as_tensor(img_tensorize(images.pop(i), input_format=self.input_format))
.to(self.device)
.float(),
)
# resize smallest edge
raw_sizes = torch.tensor([im.shape[:2] for im in images])
images = self.aug(images)
# transpose images and convert to torch tensors
# images = [torch.as_tensor(i.astype("float32")).permute(2, 0, 1).to(self.device) for i in images]
# now normalize before pad to avoid useless arithmetic
images = [self.normalizer(x) for x in images]
# now pad them to do the following operations
images, sizes = self.pad(images)
# Normalize
if self.size_divisibility > 0:
raise NotImplementedError()
# pad
scales_yx = torch.true_divide(raw_sizes, sizes)
if single_image:
return images[0], sizes[0], scales_yx[0]
else:
return images, sizes, scales_yx
def _scale_box(boxes, scale_yx):
boxes[:, 0::2] *= scale_yx[:, 1]
boxes[:, 1::2] *= scale_yx[:, 0]
return boxes
def _clip_box(tensor, box_size: Tuple[int, int]):
assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!"
h, w = box_size
tensor[:, 0].clamp_(min=0, max=w)
tensor[:, 1].clamp_(min=0, max=h)
tensor[:, 2].clamp_(min=0, max=w)
tensor[:, 3].clamp_(min=0, max=h)

View File

@@ -0,0 +1,99 @@
appdirs==1.4.3
argon2-cffi==20.1.0
async-generator==1.10
attrs==20.2.0
backcall==0.2.0
bleach==3.1.5
CacheControl==0.12.6
certifi==2020.6.20
cffi==1.14.2
chardet==3.0.4
click==7.1.2
colorama==0.4.3
contextlib2==0.6.0
cycler==0.10.0
datasets==1.0.0
decorator==4.4.2
defusedxml==0.6.0
dill==0.3.2
distlib==0.3.0
distro==1.4.0
entrypoints==0.3
filelock==3.0.12
future==0.18.2
html5lib==1.0.1
idna==2.8
ipaddr==2.2.0
ipykernel==5.3.4
ipython
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.17.2
Jinja2==2.11.2
joblib==0.16.0
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.1.7
jupyter-console==6.2.0
jupyter-core==4.6.3
jupyterlab-pygments==0.1.1
kiwisolver==1.2.0
lockfile==0.12.2
MarkupSafe==1.1.1
matplotlib==3.3.1
mistune==0.8.4
msgpack==0.6.2
nbclient==0.5.0
nbconvert==6.0.1
nbformat==5.0.7
nest-asyncio==1.4.0
notebook==6.1.4
numpy==1.19.2
opencv-python==4.4.0.42
packaging==20.3
pandas==1.1.2
pandocfilters==1.4.2
parso==0.7.1
pep517==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==7.2.0
progress==1.5
prometheus-client==0.8.0
prompt-toolkit==3.0.7
ptyprocess==0.6.0
pyaml==20.4.0
pyarrow==1.0.1
pycparser==2.20
Pygments==2.6.1
pyparsing==2.4.6
pyrsistent==0.16.0
python-dateutil==2.8.1
pytoml==0.1.21
pytz==2020.1
PyYAML==5.3.1
pyzmq==19.0.2
qtconsole==4.7.7
QtPy==1.9.0
regex==2020.7.14
requests==2.22.0
retrying==1.3.3
sacremoses==0.0.43
Send2Trash==1.5.0
sentencepiece==0.1.91
six==1.14.0
terminado==0.8.3
testpath==0.4.4
tokenizers==0.8.1rc2
torch==1.6.0
torchvision==0.7.0
tornado==6.0.4
tqdm==4.48.2
traitlets
transformers==3.5.1
urllib3==1.25.8
wcwidth==0.2.5
webencodings==0.5.1
wget==3.2
widgetsnbextension==3.5.1
xxhash==2.0.0

View File

@@ -0,0 +1,559 @@
"""
coding=utf-8
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal, Huggingface team :)
Adapted From Facebook Inc, Detectron2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.import copy
"""
import copy
import fnmatch
import json
import os
import pickle as pkl
import shutil
import sys
import tarfile
import tempfile
from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
from hashlib import sha256
from io import BytesIO
from pathlib import Path
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
import cv2
import requests
import wget
from filelock import FileLock
from yaml import Loader, dump, load
try:
import torch
_torch_available = True
except ImportError:
_torch_available = False
try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
default_cache_path = os.path.join(torch_cache_home, "transformers")
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
PATH = "/".join(str(Path(__file__).resolve()).split("/")[:-1])
CONFIG = os.path.join(PATH, "config.yaml")
ATTRIBUTES = os.path.join(PATH, "attributes.txt")
OBJECTS = os.path.join(PATH, "objects.txt")
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.yaml"
def load_labels(objs=OBJECTS, attrs=ATTRIBUTES):
vg_classes = []
with open(objs) as f:
for object in f.readlines():
vg_classes.append(object.split(",")[0].lower().strip())
vg_attrs = []
with open(attrs) as f:
for object in f.readlines():
vg_attrs.append(object.split(",")[0].lower().strip())
return vg_classes, vg_attrs
def load_checkpoint(ckp):
r = OrderedDict()
with open(ckp, "rb") as f:
ckp = pkl.load(f)["model"]
for k in copy.deepcopy(list(ckp.keys())):
v = ckp.pop(k)
if isinstance(v, np.ndarray):
v = torch.tensor(v)
else:
assert isinstance(v, torch.tensor), type(v)
r[k] = v
return r
class Config:
_pointer = {}
def __init__(self, dictionary: dict, name: str = "root", level=0):
self._name = name
self._level = level
d = {}
for k, v in dictionary.items():
if v is None:
raise ValueError()
k = copy.deepcopy(k)
v = copy.deepcopy(v)
if isinstance(v, dict):
v = Config(v, name=k, level=level + 1)
d[k] = v
setattr(self, k, v)
self._pointer = d
def __repr__(self):
return str(list((self._pointer.keys())))
def __setattr__(self, key, val):
self.__dict__[key] = val
self.__dict__[key.upper()] = val
levels = key.split(".")
last_level = len(levels) - 1
pointer = self._pointer
if len(levels) > 1:
for i, l in enumerate(levels):
if hasattr(self, l) and isinstance(getattr(self, l), Config):
setattr(getattr(self, l), ".".join(levels[i:]), val)
if l == last_level:
pointer[l] = val
else:
pointer = pointer[l]
def to_dict(self):
return self._pointer
def dump_yaml(self, data, file_name):
with open(f"{file_name}", "w") as stream:
dump(data, stream)
def dump_json(self, data, file_name):
with open(f"{file_name}", "w") as stream:
json.dump(data, stream)
@staticmethod
def load_yaml(config):
with open(config) as stream:
data = load(stream, Loader=Loader)
return data
def __str__(self):
t = " "
if self._name != "root":
r = f"{t * (self._level-1)}{self._name}:\n"
else:
r = ""
level = self._level
for i, (k, v) in enumerate(self._pointer.items()):
if isinstance(v, Config):
r += f"{t * (self._level)}{v}\n"
self._level += 1
else:
r += f"{t * (self._level)}{k}: {v} ({type(v).__name__})\n"
self._level = level
return r[:-1]
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
return cls(config_dict)
@classmethod
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(
config_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
)
# Load config dict
if resolved_config_file is None:
raise EnvironmentError
config_file = Config.load_yaml(resolved_config_file)
except EnvironmentError:
msg = "Can't load config for"
raise EnvironmentError(msg)
if resolved_config_file == config_file:
print("loading configuration file from path")
else:
print("loading configuration file cache")
return Config.load_yaml(resolved_config_file), kwargs
# quick compare tensors
def compare(in_tensor):
out_tensor = torch.load("dump.pt", map_location=in_tensor.device)
n1 = in_tensor.numpy()
n2 = out_tensor.numpy()[0]
print(n1.shape, n1[0, 0, :5])
print(n2.shape, n2[0, 0, :5])
assert np.allclose(
n1, n2, rtol=0.01, atol=0.1
), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
raise Exception("tensors are all good")
# Hugging face functions below
def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
legacy_format = "/" not in model_id
if legacy_format:
return f"{endpoint}/{model_id}-{filename}"
else:
return f"{endpoint}/{model_id}/{filename}"
def http_get(
url,
temp_file,
proxies=None,
resume_size=0,
user_agent=None,
):
ua = "python/{}".format(sys.version.split()[0])
if _torch_available:
ua += "; torch/{}".format(torch.__version__)
if isinstance(user_agent, dict):
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += "; " + user_agent
headers = {"user-agent": ua}
if resume_size > 0:
headers["Range"] = "bytes=%d-" % (resume_size,)
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
if response.status_code == 416: # Range not satisfiable
return
content_length = response.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(
unit="B",
unit_scale=True,
total=total,
initial=resume_size,
desc="Downloading",
)
for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(
url,
cache_dir=None,
force_download=False,
proxies=None,
etag_timeout=10,
resume_download=False,
user_agent=None,
local_files_only=False,
):
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
etag = None
if not local_files_only:
try:
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
if response.status_code == 200:
etag = response.headers.get("ETag")
except (EnvironmentError, requests.exceptions.Timeout):
# etag is already None
pass
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
# try to get the last downloaded one
if etag is None:
if os.path.exists(cache_path):
return cache_path
else:
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
if not file.endswith(".json") and not file.endswith(".lock")
]
if len(matching_files) > 0:
return os.path.join(cache_dir, matching_files[-1])
else:
# If files cannot be found and local_files_only=True,
# the models might've been found if local_files_only=False
# Notify the user about that
if local_files_only:
raise ValueError(
"Cannot find the requested files in the cached path and outgoing traffic has been"
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False."
)
return None
# From now on, etag is not None.
if os.path.exists(cache_path) and not force_download:
return cache_path
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + ".lock"
with FileLock(lock_path):
# If the download just completed while the lock was activated.
if os.path.exists(cache_path) and not force_download:
# Even if returning early like here, the lock will be released.
return cache_path
if resume_download:
incomplete_path = cache_path + ".incomplete"
@contextmanager
def _resumable_file_manager():
with open(incomplete_path, "a+b") as f:
yield f
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
else:
resume_size = 0
else:
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
resume_size = 0
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
print(
"%s not found in cache or force_download set to True, downloading to %s",
url,
temp_file.name,
)
http_get(
url,
temp_file,
proxies=proxies,
resume_size=resume_size,
user_agent=user_agent,
)
os.replace(temp_file.name, cache_path)
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
json.dump(meta, meta_file)
return cache_path
def url_to_filename(url, etag=None):
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes)
filename += "." + etag_hash.hexdigest()
if url.endswith(".h5"):
filename += ".h5"
return filename
def cached_path(
url_or_filename,
cache_dir=None,
force_download=False,
proxies=None,
resume_download=False,
user_agent=None,
extract_compressed_file=False,
force_extract=False,
local_files_only=False,
):
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if is_remote_url(url_or_filename):
# URL, so get it from the cache (downloading if necessary)
output_path = get_from_cache(
url_or_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
user_agent=user_agent,
local_files_only=local_files_only,
)
elif os.path.exists(url_or_filename):
# File, and it exists.
output_path = url_or_filename
elif urlparse(url_or_filename).scheme == "":
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
if extract_compressed_file:
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
return output_path
# Path where we extract compressed archives
# We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
output_dir, output_file = os.path.split(output_path)
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
return output_path_extracted
# Prevent parallel extractions
lock_path = output_path + ".lock"
with FileLock(lock_path):
shutil.rmtree(output_path_extracted, ignore_errors=True)
os.makedirs(output_path_extracted)
if is_zipfile(output_path):
with ZipFile(output_path, "r") as zip_file:
zip_file.extractall(output_path_extracted)
zip_file.close()
elif tarfile.is_tarfile(output_path):
tar_file = tarfile.open(output_path)
tar_file.extractall(output_path_extracted)
tar_file.close()
else:
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
return output_path_extracted
return output_path
def get_data(query, delim=","):
assert isinstance(query, str)
if os.path.isfile(query):
with open(query) as f:
data = eval(f.read())
else:
req = requests.get(query)
try:
data = requests.json()
except Exception:
data = req.content.decode()
assert data is not None, "could not connect"
try:
data = eval(data)
except Exception:
data = data.split("\n")
req.close()
return data
def get_image_from_url(url):
response = requests.get(url)
img = np.array(Image.open(BytesIO(response.content)))
return img
# to load legacy frcnn checkpoint from detectron
def load_frcnn_pkl_from_url(url):
fn = url.split("/")[-1]
if fn not in os.listdir(os.getcwd()):
wget.download(url)
with open(fn, "rb") as stream:
weights = pkl.load(stream)
model = weights.pop("model")
new = {}
for k, v in model.items():
new[k] = torch.from_numpy(v)
if "running_var" in k:
zero = torch.Tensor([0])
k2 = k.replace("running_var", "num_batches_tracked")
new[k2] = zero
return new
def get_demo_path():
print(f"{os.path.abspath(os.path.join(PATH, os.pardir))}/demo.ipynb")
def img_tensorize(im, input_format="RGB"):
assert isinstance(im, str)
if os.path.isfile(im):
img = cv2.imread(im)
else:
img = get_image_from_url(im)
assert img is not None, f"could not connect to: {im}"
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if input_format == "RGB":
img = img[:, :, ::-1]
return img
def chunk(images, batch=1):
return (images[i : i + batch] for i in range(0, len(images), batch))

View File

@@ -0,0 +1,499 @@
"""
coding=utf-8
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
Adapted From Facebook Inc, Detectron2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.import copy
"""
import colorsys
import io
import matplotlib as mpl
import matplotlib.colors as mplc
import matplotlib.figure as mplfigure
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
import cv2
from utils import img_tensorize
_SMALL_OBJ = 1000
class SingleImageViz:
def __init__(
self,
img,
scale=1.2,
edgecolor="g",
alpha=0.5,
linestyle="-",
saveas="test_out.jpg",
rgb=True,
pynb=False,
id2obj=None,
id2attr=None,
pad=0.7,
):
"""
img: an RGB image of shape (H, W, 3).
"""
if isinstance(img, torch.Tensor):
img = img.numpy().astype("np.uint8")
if isinstance(img, str):
img = img_tensorize(img)
assert isinstance(img, np.ndarray)
width, height = img.shape[1], img.shape[0]
fig = mplfigure.Figure(frameon=False)
dpi = fig.get_dpi()
width_in = (width * scale + 1e-2) / dpi
height_in = (height * scale + 1e-2) / dpi
fig.set_size_inches(width_in, height_in)
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
ax.axis("off")
ax.set_xlim(0.0, width)
ax.set_ylim(height)
self.saveas = saveas
self.rgb = rgb
self.pynb = pynb
self.img = img
self.edgecolor = edgecolor
self.alpha = 0.5
self.linestyle = linestyle
self.font_size = int(np.sqrt(min(height, width)) * scale // 3)
self.width = width
self.height = height
self.scale = scale
self.fig = fig
self.ax = ax
self.pad = pad
self.id2obj = id2obj
self.id2attr = id2attr
self.canvas = FigureCanvasAgg(fig)
def add_box(self, box, color=None):
if color is None:
color = self.edgecolor
(x0, y0, x1, y1) = box
width = x1 - x0
height = y1 - y0
self.ax.add_patch(
mpl.patches.Rectangle(
(x0, y0),
width,
height,
fill=False,
edgecolor=color,
linewidth=self.font_size // 3,
alpha=self.alpha,
linestyle=self.linestyle,
)
)
def draw_boxes(self, boxes, obj_ids=None, obj_scores=None, attr_ids=None, attr_scores=None):
if len(boxes.shape) > 2:
boxes = boxes[0]
if len(obj_ids.shape) > 1:
obj_ids = obj_ids[0]
if len(obj_scores.shape) > 1:
obj_scores = obj_scores[0]
if len(attr_ids.shape) > 1:
attr_ids = attr_ids[0]
if len(attr_scores.shape) > 1:
attr_scores = attr_scores[0]
if isinstance(boxes, torch.Tensor):
boxes = boxes.numpy()
if isinstance(boxes, list):
boxes = np.array(boxes)
assert isinstance(boxes, np.ndarray)
areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
sorted_idxs = np.argsort(-areas).tolist()
boxes = boxes[sorted_idxs] if boxes is not None else None
obj_ids = obj_ids[sorted_idxs] if obj_ids is not None else None
obj_scores = obj_scores[sorted_idxs] if obj_scores is not None else None
attr_ids = attr_ids[sorted_idxs] if attr_ids is not None else None
attr_scores = attr_scores[sorted_idxs] if attr_scores is not None else None
assigned_colors = [self._random_color(maximum=1) for _ in range(len(boxes))]
assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
if obj_ids is not None:
labels = self._create_text_labels_attr(obj_ids, obj_scores, attr_ids, attr_scores)
for i in range(len(boxes)):
color = assigned_colors[i]
self.add_box(boxes[i], color)
self.draw_labels(labels[i], boxes[i], color)
def draw_labels(self, label, box, color):
x0, y0, x1, y1 = box
text_pos = (x0, y0)
instance_area = (y1 - y0) * (x1 - x0)
small = _SMALL_OBJ * self.scale
if instance_area < small or y1 - y0 < 40 * self.scale:
if y1 >= self.height - 5:
text_pos = (x1, y0)
else:
text_pos = (x0, y1)
height_ratio = (y1 - y0) / np.sqrt(self.height * self.width)
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
font_size = np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
font_size *= 0.75 * self.font_size
self.draw_text(
text=label,
position=text_pos,
color=lighter_color,
)
def draw_text(
self,
text,
position,
color="g",
ha="left",
):
rotation = 0
font_size = self.font_size
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
color[np.argmax(color)] = max(0.8, np.max(color))
bbox = {
"facecolor": "black",
"alpha": self.alpha,
"pad": self.pad,
"edgecolor": "none",
}
x, y = position
self.ax.text(
x,
y,
text,
size=font_size * self.scale,
family="sans-serif",
bbox=bbox,
verticalalignment="top",
horizontalalignment=ha,
color=color,
zorder=10,
rotation=rotation,
)
def save(self, saveas=None):
if saveas is None:
saveas = self.saveas
if saveas.lower().endswith(".jpg") or saveas.lower().endswith(".png"):
cv2.imwrite(
saveas,
self._get_buffer()[:, :, ::-1],
)
else:
self.fig.savefig(saveas)
def _create_text_labels_attr(self, classes, scores, attr_classes, attr_scores):
labels = [self.id2obj[i] for i in classes]
attr_labels = [self.id2attr[i] for i in attr_classes]
labels = [
f"{label} {score:.2f} {attr} {attr_score:.2f}"
for label, score, attr, attr_score in zip(labels, scores, attr_labels, attr_scores)
]
return labels
def _create_text_labels(self, classes, scores):
labels = [self.id2obj[i] for i in classes]
if scores is not None:
if labels is None:
labels = ["{:.0f}%".format(s * 100) for s in scores]
else:
labels = ["{} {:.0f}%".format(li, s * 100) for li, s in zip(labels, scores)]
return labels
def _random_color(self, maximum=255):
idx = np.random.randint(0, len(_COLORS))
ret = _COLORS[idx] * maximum
if not self.rgb:
ret = ret[::-1]
return ret
def _get_buffer(self):
if not self.pynb:
s, (width, height) = self.canvas.print_to_buffer()
if (width, height) != (self.width, self.height):
img = cv2.resize(self.img, (width, height))
else:
img = self.img
else:
buf = io.BytesIO() # works for cairo backend
self.canvas.print_rgba(buf)
width, height = self.width, self.height
s = buf.getvalue()
img = self.img
buffer = np.frombuffer(s, dtype="uint8")
img_rgba = buffer.reshape(height, width, 4)
rgb, alpha = np.split(img_rgba, [3], axis=2)
try:
import numexpr as ne # fuse them with numexpr
visualized_image = ne.evaluate("img * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
except ImportError:
alpha = alpha.astype("float32") / 255.0
visualized_image = img * (1 - alpha) + rgb * alpha
return visualized_image.astype("uint8")
def _change_color_brightness(self, color, brightness_factor):
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
color = mplc.to_rgb(color)
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
return modified_color
# Color map
_COLORS = (
np.array(
[
0.000,
0.447,
0.741,
0.850,
0.325,
0.098,
0.929,
0.694,
0.125,
0.494,
0.184,
0.556,
0.466,
0.674,
0.188,
0.301,
0.745,
0.933,
0.635,
0.078,
0.184,
0.300,
0.300,
0.300,
0.600,
0.600,
0.600,
1.000,
0.000,
0.000,
1.000,
0.500,
0.000,
0.749,
0.749,
0.000,
0.000,
1.000,
0.000,
0.000,
0.000,
1.000,
0.667,
0.000,
1.000,
0.333,
0.333,
0.000,
0.333,
0.667,
0.000,
0.333,
1.000,
0.000,
0.667,
0.333,
0.000,
0.667,
0.667,
0.000,
0.667,
1.000,
0.000,
1.000,
0.333,
0.000,
1.000,
0.667,
0.000,
1.000,
1.000,
0.000,
0.000,
0.333,
0.500,
0.000,
0.667,
0.500,
0.000,
1.000,
0.500,
0.333,
0.000,
0.500,
0.333,
0.333,
0.500,
0.333,
0.667,
0.500,
0.333,
1.000,
0.500,
0.667,
0.000,
0.500,
0.667,
0.333,
0.500,
0.667,
0.667,
0.500,
0.667,
1.000,
0.500,
1.000,
0.000,
0.500,
1.000,
0.333,
0.500,
1.000,
0.667,
0.500,
1.000,
1.000,
0.500,
0.000,
0.333,
1.000,
0.000,
0.667,
1.000,
0.000,
1.000,
1.000,
0.333,
0.000,
1.000,
0.333,
0.333,
1.000,
0.333,
0.667,
1.000,
0.333,
1.000,
1.000,
0.667,
0.000,
1.000,
0.667,
0.333,
1.000,
0.667,
0.667,
1.000,
0.667,
1.000,
1.000,
1.000,
0.000,
1.000,
1.000,
0.333,
1.000,
1.000,
0.667,
1.000,
0.333,
0.000,
0.000,
0.500,
0.000,
0.000,
0.667,
0.000,
0.000,
0.833,
0.000,
0.000,
1.000,
0.000,
0.000,
0.000,
0.167,
0.000,
0.000,
0.333,
0.000,
0.000,
0.500,
0.000,
0.000,
0.667,
0.000,
0.000,
0.833,
0.000,
0.000,
1.000,
0.000,
0.000,
0.000,
0.167,
0.000,
0.000,
0.333,
0.000,
0.000,
0.500,
0.000,
0.000,
0.667,
0.000,
0.000,
0.833,
0.000,
0.000,
1.000,
0.000,
0.000,
0.000,
0.143,
0.143,
0.143,
0.857,
0.857,
0.857,
1.000,
1.000,
1.000,
]
)
.astype(np.float32)
.reshape(-1, 3)
)

View File

@@ -0,0 +1,953 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-pruning Masked BERT on sequence classification on GLUE."""
import argparse
import glob
import json
import logging
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
from transformers import (
WEIGHTS_NAME,
AdamW,
BertConfig,
BertForSequenceClassification,
BertTokenizer,
get_linear_schedule_with_warmup,
)
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
logger = logging.getLogger(__name__)
MODEL_CLASSES = {
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
"masked_bert": (MaskedBertConfig, MaskedBertForSequenceClassification, BertTokenizer),
}
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def schedule_threshold(
step: int,
total_step: int,
warmup_steps: int,
initial_threshold: float,
final_threshold: float,
initial_warmup: int,
final_warmup: int,
final_lambda: float,
):
if step <= initial_warmup * warmup_steps:
threshold = initial_threshold
elif step > (total_step - final_warmup * warmup_steps):
threshold = final_threshold
else:
spars_warmup_steps = initial_warmup * warmup_steps
spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
regu_lambda = final_lambda * threshold / final_threshold
return threshold, regu_lambda
def regularization(model: nn.Module, mode: str):
regu, counter = 0, 0
for name, param in model.named_parameters():
if "mask_scores" in name:
if mode == "l1":
regu += torch.norm(torch.sigmoid(param), p=1) / param.numel()
elif mode == "l0":
regu += torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1)).sum() / param.numel()
else:
ValueError("Don't know this mode.")
counter += 1
return regu / counter
def train(args, train_dataset, model, tokenizer, teacher=None):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter(log_dir=args.output_dir)
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if "mask_score" in n and p.requires_grad],
"lr": args.mask_scores_learning_rate,
},
{
"params": [
p
for n, p in model.named_parameters()
if "mask_score" not in n and p.requires_grad and not any(nd in n for nd in no_decay)
],
"lr": args.learning_rate,
"weight_decay": args.weight_decay,
},
{
"params": [
p
for n, p in model.named_parameters()
if "mask_score" not in n and p.requires_grad and any(nd in n for nd in no_decay)
],
"lr": args.learning_rate,
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
# Check if saved optimizer or scheduler states exist
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
os.path.join(args.model_name_or_path, "scheduler.pt")
):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True,
)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size
* args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
# Distillation
if teacher is not None:
logger.info(" Training with distillation")
global_step = 0
# Global TopK
if args.global_topk:
threshold_mem = None
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
# set global_step to global_step of last saved checkpoint from model path
try:
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
except ValueError:
global_step = 0
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(
epochs_trained,
int(args.num_train_epochs),
desc="Epoch",
disable=args.local_rank not in [-1, 0],
)
set_seed(args) # Added here for reproducibility
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
model.train()
batch = tuple(t.to(args.device) for t in batch)
threshold, regu_lambda = schedule_threshold(
step=global_step,
total_step=t_total,
warmup_steps=args.warmup_steps,
final_threshold=args.final_threshold,
initial_threshold=args.initial_threshold,
final_warmup=args.final_warmup,
initial_warmup=args.initial_warmup,
final_lambda=args.final_lambda,
)
# Global TopK
if args.global_topk:
if threshold == 1.0:
threshold = -1e2 # Or an indefinitely low quantity
else:
if (threshold_mem is None) or (global_step % args.global_topk_frequency_compute == 0):
# Sort all the values to get the global topK
concat = torch.cat(
[param.view(-1) for name, param in model.named_parameters() if "mask_scores" in name]
)
n = concat.numel()
kth = max(n - (int(n * threshold) + 1), 1)
threshold_mem = concat.kthvalue(kth).values.item()
threshold = threshold_mem
else:
threshold = threshold_mem
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
if args.model_type != "distilbert":
inputs["token_type_ids"] = (
batch[2] if args.model_type in ["bert", "masked_bert", "xlnet", "albert"] else None
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
if "masked" in args.model_type:
inputs["threshold"] = threshold
outputs = model(**inputs)
loss, logits_stu = outputs # model outputs are always tuple in transformers (see doc)
# Distillation loss
if teacher is not None:
if "token_type_ids" not in inputs:
inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
with torch.no_grad():
(logits_tea,) = teacher(
input_ids=inputs["input_ids"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"],
)
loss_logits = (
F.kl_div(
input=F.log_softmax(logits_stu / args.temperature, dim=-1),
target=F.softmax(logits_tea / args.temperature, dim=-1),
reduction="batchmean",
)
* (args.temperature ** 2)
)
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
# Regularization
if args.regularization is not None:
regu_ = regularization(model=model, mode=args.regularization)
loss = loss + regu_lambda * regu_
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
len(epoch_iterator) <= args.gradient_accumulation_steps
and (step + 1) == len(epoch_iterator)
):
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
tb_writer.add_scalar("threshold", threshold, global_step)
for name, param in model.named_parameters():
if not param.requires_grad:
continue
tb_writer.add_scalar("parameter_mean/" + name, param.data.mean(), global_step)
tb_writer.add_scalar("parameter_std/" + name, param.data.std(), global_step)
tb_writer.add_scalar("parameter_min/" + name, param.data.min(), global_step)
tb_writer.add_scalar("parameter_max/" + name, param.data.max(), global_step)
tb_writer.add_scalar("grad_mean/" + name, param.grad.data.mean(), global_step)
tb_writer.add_scalar("grad_std/" + name, param.grad.data.std(), global_step)
if args.regularization is not None and "mask_scores" in name:
if args.regularization == "l1":
perc = (torch.sigmoid(param) > threshold).sum().item() / param.numel()
elif args.regularization == "l0":
perc = (torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1))).sum().item() / param.numel()
tb_writer.add_scalar("retained_weights_perc/" + name, perc, global_step)
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
global_step += 1
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
logs = {}
if (
args.local_rank == -1 and args.evaluate_during_training
): # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer)
for key, value in results.items():
eval_key = "eval_{}".format(key)
logs[eval_key] = value
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
learning_rate_scalar = scheduler.get_lr()
logs["learning_rate"] = learning_rate_scalar[0]
if len(learning_rate_scalar) > 1:
for idx, lr in enumerate(learning_rate_scalar[1:]):
logs[f"learning_rate/{idx+1}"] = lr
logs["loss"] = loss_scalar
if teacher is not None:
logs["loss/distil"] = loss_logits.item()
if args.regularization is not None:
logs["loss/regularization"] = regu_.item()
if (teacher is not None) or (args.regularization is not None):
if (teacher is not None) and (args.regularization is not None):
logs["loss/instant_ce"] = (
loss.item()
- regu_lambda * logs["loss/regularization"]
- args.alpha_distil * logs["loss/distil"]
) / args.alpha_ce
elif teacher is not None:
logs["loss/instant_ce"] = (
loss.item() - args.alpha_distil * logs["loss/distil"]
) / args.alpha_ce
else:
logs["loss/instant_ce"] = loss.item() - regu_lambda * logs["loss/regularization"]
logging_loss = tr_loss
for key, value in logs.items():
tb_writer.add_scalar(key, value, global_step)
print(json.dumps({**logs, **{"step": global_step}}))
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir)
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
logger.info("Saving optimizer and scheduler states to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
if args.max_steps > 0 and global_step > args.max_steps:
train_iterator.close()
break
if args.local_rank in [-1, 0]:
tb_writer.close()
return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, prefix=""):
# Loop to handle MNLI double evaluation (matched, mis-matched)
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
eval_outputs_dirs = (args.output_dir, args.output_dir + "/MM") if args.task_name == "mnli" else (args.output_dir,)
results = {}
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
os.makedirs(eval_output_dir)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = torch.nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss = 0.0
nb_eval_steps = 0
preds = None
out_label_ids = None
# Global TopK
if args.global_topk:
threshold_mem = None
for batch in tqdm(eval_dataloader, desc="Evaluating"):
model.eval()
batch = tuple(t.to(args.device) for t in batch)
with torch.no_grad():
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
if args.model_type != "distilbert":
inputs["token_type_ids"] = (
batch[2] if args.model_type in ["bert", "masked_bert", "xlnet", "albert"] else None
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
if "masked" in args.model_type:
inputs["threshold"] = args.final_threshold
if args.global_topk:
if threshold_mem is None:
concat = torch.cat(
[param.view(-1) for name, param in model.named_parameters() if "mask_scores" in name]
)
n = concat.numel()
kth = max(n - (int(n * args.final_threshold) + 1), 1)
threshold_mem = concat.kthvalue(kth).values.item()
inputs["threshold"] = threshold_mem
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if preds is None:
preds = logits.detach().cpu().numpy()
out_label_ids = inputs["labels"].detach().cpu().numpy()
else:
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
if args.output_mode == "classification":
from scipy.special import softmax
probs = softmax(preds, axis=-1)
entropy = np.exp((-probs * np.log(probs)).sum(axis=-1).mean())
preds = np.argmax(preds, axis=1)
elif args.output_mode == "regression":
preds = np.squeeze(preds)
result = compute_metrics(eval_task, preds, out_label_ids)
results.update(result)
if entropy is not None:
result["eval_avg_entropy"] = entropy
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
return results
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
processor = processors[task]()
output_mode = output_modes[task]
# Load data features from cache or dataset file
cached_features_file = os.path.join(
args.data_dir,
"cached_{}_{}_{}_{}".format(
"dev" if evaluate else "train",
list(filter(None, args.model_name_or_path.split("/"))).pop(),
str(args.max_seq_length),
str(task),
),
)
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels()
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1]
examples = (
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
)
features = convert_examples_to_features(
examples,
tokenizer,
max_length=args.max_seq_length,
label_list=label_list,
output_mode=output_mode,
)
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
if args.local_rank == 0 and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
if output_mode == "classification":
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
elif output_mode == "regression":
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
return dataset
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
)
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
# Other parameters
parser.add_argument(
"--config_name",
default="",
type=str,
help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
"--tokenizer_name",
default="",
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
)
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
parser.add_argument(
"--evaluate_during_training",
action="store_true",
help="Run evaluation during training at each logging step.",
)
parser.add_argument(
"--do_lower_case",
action="store_true",
help="Set this flag if you are using an uncased model.",
)
parser.add_argument(
"--per_gpu_train_batch_size",
default=8,
type=int,
help="Batch size per GPU/CPU for training.",
)
parser.add_argument(
"--per_gpu_eval_batch_size",
default=8,
type=int,
help="Batch size per GPU/CPU for evaluation.",
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
# Pruning parameters
parser.add_argument(
"--mask_scores_learning_rate",
default=1e-2,
type=float,
help="The Adam initial learning rate of the mask scores.",
)
parser.add_argument(
"--initial_threshold", default=1.0, type=float, help="Initial value of the threshold (for scheduling)."
)
parser.add_argument(
"--final_threshold", default=0.7, type=float, help="Final value of the threshold (for scheduling)."
)
parser.add_argument(
"--initial_warmup",
default=1,
type=int,
help="Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
"at its `initial_threshold` value (sparsity schedule).",
)
parser.add_argument(
"--final_warmup",
default=2,
type=int,
help="Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
"at its final_threshold value (sparsity schedule).",
)
parser.add_argument(
"--pruning_method",
default="topK",
type=str,
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning).",
)
parser.add_argument(
"--mask_init",
default="constant",
type=str,
help="Initialization method for the mask scores. Choices: constant, uniform, kaiming.",
)
parser.add_argument(
"--mask_scale", default=0.0, type=float, help="Initialization parameter for the chosen initialization method."
)
parser.add_argument("--regularization", default=None, help="Add L0 or L1 regularization to the mask scores.")
parser.add_argument(
"--final_lambda",
default=0.0,
type=float,
help="Regularization intensity (used in conjunction with `regularization`.",
)
parser.add_argument("--global_topk", action="store_true", help="Global TopK on the Scores.")
parser.add_argument(
"--global_topk_frequency_compute",
default=25,
type=int,
help="Frequency at which we compute the TopK global threshold.",
)
# Distillation parameters (optional)
parser.add_argument(
"--teacher_type",
default=None,
type=str,
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
)
parser.add_argument(
"--teacher_name_or_path",
default=None,
type=str,
help="Path to the already fine-tuned teacher model. Only for distillation.",
)
parser.add_argument(
"--alpha_ce", default=0.5, type=float, help="Cross entropy loss linear weight. Only for distillation."
)
parser.add_argument(
"--alpha_distil", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
)
parser.add_argument(
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--num_train_epochs",
default=3.0,
type=float,
help="Total number of training epochs to perform.",
)
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--eval_all_checkpoints",
action="store_true",
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
)
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
parser.add_argument(
"--overwrite_output_dir",
action="store_true",
help="Overwrite the content of the output directory",
)
parser.add_argument(
"--overwrite_cache",
action="store_true",
help="Overwrite the cached training and evaluation sets",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
# Regularization
if args.regularization == "null":
args.regularization = None
if (
os.path.exists(args.output_dir)
and os.listdir(args.output_dir)
and args.do_train
and not args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
# Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1
args.device = device
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank,
device,
args.n_gpu,
bool(args.local_rank != -1),
args.fp16,
)
# Set seed
set_seed(args)
# Prepare GLUE task
args.task_name = args.task_name.lower()
if args.task_name not in processors:
raise ValueError("Task not found: %s" % (args.task_name))
processor = processors[args.task_name]()
args.output_mode = output_modes[args.task_name]
label_list = processor.get_labels()
num_labels = len(label_list)
# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path,
num_labels=num_labels,
finetuning_task=args.task_name,
cache_dir=args.cache_dir if args.cache_dir else None,
pruning_method=args.pruning_method,
mask_init=args.mask_init,
mask_scale=args.mask_scale,
)
tokenizer = tokenizer_class.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
cache_dir=args.cache_dir if args.cache_dir else None,
do_lower_case=args.do_lower_case,
)
model = model_class.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None,
)
if args.teacher_type is not None:
assert args.teacher_name_or_path is not None
assert args.alpha_distil > 0.0
assert args.alpha_distil + args.alpha_ce > 0.0
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path)
teacher = teacher_model_class.from_pretrained(
args.teacher_name_or_path,
from_tf=False,
config=teacher_config,
cache_dir=args.cache_dir if args.cache_dir else None,
)
teacher.to(args.device)
else:
teacher = None
if args.local_rank == 0:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
model.to(args.device)
logger.info("Training/evaluation parameters %s", args)
# Training
if args.do_train:
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device)
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
results.update(result)
return results
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
torch>=1.4.0
-e git+https://github.com/huggingface/transformers.git@352d5472b0c1dec0f420d606d16747d851b4bda8#egg=transformers
knockknock>=0.1.8.1
h5py>=2.10.0
numpy>=1.18.2
scipy>=1.4.1