Add QDQBert model and quantization examples of SQUAD task (#14066)
* clean up branch for add-qdqbert-model
* README update for QAT example; update docstrings in modeling_qdqbert.py
* Update qdqbert.rst
* Update README.md
* Update README.md
* calibration data using training set; QAT example runs in fp32
* re-use BERTtokenizer for qdqbert
* Update docs/source/model_doc/qdqbert.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update docs/source/model_doc/qdqbert.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update docs/source/model_doc/qdqbert.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* remove qdqbert tokenizer
* Update qdqbert.rst
* update evaluate-hf-trt-qa.py
* update configuration_qdqbert.py
* update modeling_qdqbert.py: add copied statement; replace assert with ValueError
* update copied from statement
* add is_quantization_available; run make fix-copies
* unittest add require_quantization
* add backend dependency to qdqbert model
* update README; update evaluate script; make style
* lint
* docs qdqbert update
* circleci build_doc add pytorch-quantization for qdqbert
* update README
* update example readme with instructions to upgrade TensorRT to 8.2
* Update src/transformers/models/qdqbert/configuration_qdqbert.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Update src/transformers/models/qdqbert/configuration_qdqbert.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Update src/transformers/models/qdqbert/configuration_qdqbert.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Update src/transformers/models/qdqbert/configuration_qdqbert.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* change quantization to pytorch_quantization for backend requirement
* feed_forward_chunking not supported in QDQBert
* make style
* update model docstrings and comments in testing scripts
* rename example to quantization-qdqbert; rename example scripts from qat to quant
* Update src/transformers/models/qdqbert/modeling_qdqbert.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* rm experimental functions in quant_trainer
* qa cleanup
* make fix-copies for docs index.rst
* fix doctree; use post_init() for qdqbert
* fix early device assignment for qdqbert
* fix CI:Model templates runner

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
37
examples/research_projects/quantization-qdqbert/Dockerfile
Normal file
@@ -0,0 +1,37 @@
# coding=utf-8
# Copyright 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvcr.io/nvidia/pytorch:21.07-py3
LABEL maintainer="Hugging Face"
LABEL repository="transformers"

RUN apt-get update
RUN apt-get install sudo

RUN python3 -m pip install --no-cache-dir --upgrade pip
RUN python3 -m pip install --no-cache-dir --ignore-installed ruamel.yaml \
    mkl \
    absl-py \
    yamlpy \
    tensorboardX
RUN python3 -m pip install --no-cache-dir \
    pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com

WORKDIR /workspace
COPY . transformers/
RUN cd transformers/ && \
    python3 -m pip install --no-cache-dir .

RUN python3 -m pip install --no-cache-dir datasets \
    accelerate
197
examples/research_projects/quantization-qdqbert/README.md
Normal file
@@ -0,0 +1,197 @@
<!---
Copyright 2021 NVIDIA Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Huggingface QDQBERT Quantization Example

The QDQBERT model adds fake quantization (pair of QuantizeLinear/DequantizeLinear ops) to:
 * linear layer inputs and weights
 * matmul inputs
 * residual add inputs

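To make the idea of a QuantizeLinear/DequantizeLinear pair concrete, here is a minimal pure-Python sketch of symmetric fake quantization. It is an illustration only, not the pytorch-quantization implementation: the function name `fake_quantize` and the sample values are hypothetical, and the sketch assumes a symmetric signed range with a single `amax` for the whole tensor.

```python
def fake_quantize(values, amax, num_bits=8):
    """Quantize floats to a signed integer grid, then dequantize them again.

    `amax` is the largest magnitude the grid can represent (assumed to come
    from calibration). The output stays in floating point, which is why this
    is called "fake" quantization.
    """
    qmax = 2 ** (num_bits - 1) - 1      # 127 for 8 bits
    scale = amax / qmax                 # float distance between integer levels
    result = []
    for x in values:
        q = round(x / scale)            # quantize: scale and round to an integer
        q = max(-qmax, min(qmax, q))    # saturate values outside [-amax, amax]
        result.append(q * scale)        # dequantize: map the integer back to float
    return result


# Hypothetical activations; 3.0 lies outside the range and saturates at amax.
activations = [0.0, 0.5, -1.0, 3.0]
quantized = fake_quantize(activations, amax=1.27)
print(quantized)
```

Values inside the range are snapped to the nearest grid point, while values beyond `amax` are clipped, which is why the choice of range during calibration matters.
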
In this example, we use the QDQBERT model to quantize BERT on the SQuAD task, covering Quantization Aware Training (QAT), Post Training Quantization (PTQ), and inference with TensorRT.

Required:
- [pytorch-quantization toolkit](https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization)
- [TensorRT >= 8.2](https://developer.nvidia.com/tensorrt)
- PyTorch >= 1.10.0

## Set up the environment with Dockerfile

Under the directory of `transformers/`, build the docker image:
```
docker build . -f examples/research_projects/quantization-qdqbert/Dockerfile -t bert_quantization:latest
```

Run the docker:
```
docker run --gpus all --privileged --rm -it --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 bert_quantization:latest
```

*Note that the current NGC pytorch container (pytorch:21.07-py3) ships TensorRT 8.0, which doesn't meet the requirement of TensorRT >= 8.2. One can either update the Dockerfile with the latest [NGC pytorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) once it supports TensorRT 8.2, or manually download and install [TensorRT >= 8.2](https://developer.nvidia.com/nvidia-tensorrt-download) in the container.*

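Before running the TensorRT steps, it can help to confirm that the installed version satisfies the 8.2 requirement. The sketch below is one possible check; the helper name `meets_minimum` is hypothetical, and it assumes the version is a dotted string such as the one exposed by `tensorrt.__version__`.

```python
import re


def meets_minimum(version, minimum=(8, 2)):
    """Return True if a dotted version string is at least `minimum`."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)          # keep only the leading digits
        parts.append(int(match.group()) if match else 0)
    return tuple(parts[: len(minimum)]) >= tuple(minimum)


# Inside the container one could run (requires TensorRT to be installed):
#   import tensorrt
#   print(meets_minimum(tensorrt.__version__))
print(meets_minimum("8.0.1.6"))   # an 8.0 build does not satisfy >= 8.2
print(meets_minimum("8.2.0.6"))
```
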
In the container:
```
cd transformers/examples/research_projects/quantization-qdqbert/
```

## Quantization Aware Training (QAT)

Calibrate the pretrained model, then fine-tune it with quantization awareness:

```
python3 run_quant_qa.py \
  --model_name_or_path bert-base-uncased \
  --dataset_name squad \
  --max_seq_length 128 \
  --doc_stride 32 \
  --output_dir calib/bert-base-uncased \
  --do_calib \
  --calibrator percentile \
  --percentile 99.99
```

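The `--calibrator percentile --percentile 99.99` options pick each quantization range so that rare outliers do not stretch it. The following sketch illustrates the idea with a nearest-rank percentile over raw samples. It is a simplification under stated assumptions: the name `percentile_amax` and the synthetic data are hypothetical, and the actual toolkit calibrator may work differently (for example from histograms).

```python
import math


def percentile_amax(values, percentile=99.99):
    """Return the magnitude that covers `percentile` percent of |values|."""
    magnitudes = sorted(abs(v) for v in values)
    # Nearest-rank percentile: the smallest magnitude covering the fraction.
    rank = math.ceil(percentile / 100.0 * len(magnitudes))
    return magnitudes[max(rank, 1) - 1]


# Synthetic activations: 100000 values below 10, plus one large outlier.
samples = [0.0001 * i for i in range(100000)] + [250.0]

max_amax = max(abs(v) for v in samples)   # a max-based range follows the outlier
pct_amax = percentile_amax(samples)       # the percentile range ignores it
print(max_amax, pct_amax)
```

With a max-based range the single outlier would dictate the scale for every value; the percentile range stays close to the bulk of the data at the cost of clipping the outlier.
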
```
python3 run_quant_qa.py \
  --model_name_or_path calib/bert-base-uncased \
  --dataset_name squad \
  --do_train \
  --do_eval \
  --per_device_train_batch_size 12 \
  --learning_rate 4e-5 \
  --num_train_epochs 2 \
  --max_seq_length 128 \
  --doc_stride 32 \
  --output_dir finetuned_int8/bert-base-uncased \
  --tokenizer_name bert-base-uncased \
  --save_steps 0
```

### Export QAT model to ONNX

To export the QAT model fine-tuned above:

```
python3 run_quant_qa.py \
  --model_name_or_path finetuned_int8/bert-base-uncased \
  --output_dir ./ \
  --save_onnx \
  --per_device_eval_batch_size 1 \
  --max_seq_length 128 \
  --doc_stride 32 \
  --dataset_name squad \
  --tokenizer_name bert-base-uncased
```

Use `--recalibrate-weights` to calibrate the weight ranges according to the quantizer axis. Use `--quant-per-tensor` for per-tensor quantization (the default is per-channel).
Recalibrating will affect the accuracy of the model, but the change should be minimal (< 0.5 F1).

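The difference between per-tensor and per-channel weight ranges can be sketched in a few lines. This is an illustration only: the helper names and the toy weight matrix are hypothetical, and the sketch assumes that each row of the matrix is one output channel.

```python
def per_tensor_amax(weight):
    """One range for the whole weight matrix."""
    return max(abs(w) for row in weight for w in row)


def per_channel_amax(weight):
    """One range per row (assumed here to be the output-channel axis)."""
    return [max(abs(w) for w in row) for row in weight]


# Toy weights: the first channel is much smaller than the second.
weight = [
    [0.02, -0.05, 0.01],
    [1.5, -2.0, 0.7],
]

tensor_range = per_tensor_amax(weight)
channel_ranges = per_channel_amax(weight)
print(tensor_range)     # 2.0
print(channel_ranges)   # [0.05, 2.0]

# With 8 bits the step size is amax / 127, so the small channel gets a much
# finer grid when it has its own range.
print(tensor_range / 127, channel_ranges[0] / 127)
```

With a single per-tensor range, the small first channel would be represented by only a handful of integer levels, which is the usual motivation for keeping per-channel as the default.
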
### Benchmark the INT8 QAT ONNX model inference with TensorRT using dummy input

```
trtexec --onnx=model.onnx --explicitBatch --workspace=16384 --int8 --shapes=input_ids:64x128,attention_mask:64x128,token_type_ids:64x128 --verbose
```

### Evaluate the INT8 QAT ONNX model inference with TensorRT

```
python3 evaluate-hf-trt-qa.py \
  --onnx_model_path=./model.onnx \
  --output_dir ./ \
  --per_device_eval_batch_size 64 \
  --max_seq_length 128 \
  --doc_stride 32 \
  --dataset_name squad \
  --tokenizer_name bert-base-uncased \
  --int8 \
  --seed 42
```

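The evaluation reports SQuAD metrics, including F1. As a rough guide to what that number measures, the sketch below computes a token-overlap F1 between a predicted and a reference answer. It is a simplified illustration: the function name `token_f1` is hypothetical, and unlike the official SQuAD metric it only lowercases and splits on whitespace, without stripping punctuation or articles.

```python
from collections import Counter


def token_f1(prediction, reference):
    """Token-overlap F1 between two answer strings (simplified normalization)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Multiset intersection counts each shared token at most as often as it
    # appears in both answers.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = token_f1("in Paris France", "Paris France")
print(round(score, 3))
```

Here the prediction contains one extra token, so precision is 2/3 while recall is 1, and the F1 lands between the two.
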
## Fine-tuning of FP32 model for comparison

Fine-tune an FP32 model with [transformers/examples/pytorch/question-answering/](../../pytorch/question-answering/):

```
python3 ../../pytorch/question-answering/run_qa.py \
  --model_name_or_path bert-base-uncased \
  --dataset_name squad \
  --per_device_train_batch_size 12 \
  --learning_rate 3e-5 \
  --num_train_epochs 2 \
  --max_seq_length 128 \
  --doc_stride 32 \
  --output_dir ./finetuned_fp32/bert-base-uncased \
  --save_steps 0 \
  --do_train \
  --do_eval
```

## Post Training Quantization (PTQ)

### PTQ by calibrating and evaluating the finetuned FP32 model above:

```
python3 run_quant_qa.py \
  --model_name_or_path ./finetuned_fp32/bert-base-uncased \
  --dataset_name squad \
  --calibrator percentile \
  --percentile 99.99 \
  --max_seq_length 128 \
  --doc_stride 32 \
  --output_dir ./calib/bert-base-uncased \
  --save_steps 0 \
  --do_calib \
  --do_eval
```

### Export the INT8 PTQ model to ONNX

```
python3 run_quant_qa.py \
  --model_name_or_path ./calib/bert-base-uncased \
  --output_dir ./ \
  --save_onnx \
  --per_device_eval_batch_size 1 \
  --max_seq_length 128 \
  --doc_stride 32 \
  --dataset_name squad \
  --tokenizer_name bert-base-uncased
```

### Evaluate the INT8 PTQ ONNX model inference with TensorRT

```
python3 evaluate-hf-trt-qa.py \
  --onnx_model_path=./model.onnx \
  --output_dir ./ \
  --per_device_eval_batch_size 64 \
  --max_seq_length 128 \
  --doc_stride 32 \
  --dataset_name squad \
  --tokenizer_name bert-base-uncased \
  --int8 \
  --seed 42
```

### Quantization options

The following options support different implementations and optimizations. They should be specified for both calibration and fine-tuning.

|argument|description|
|--------|-----------|
|`--quant-per-tensor`| quantize weights with one quantization range per tensor |
|`--fuse-qkv` | use a single range (the max) for quantizing QKV weights and output activations |
|`--clip-gelu N` | clip the output of GELU to a maximum of N when quantizing (e.g. 10) |
|`--disable-dropout` | disable dropout for consistent activation ranges |
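Two of the options in the table reduce to very small operations on ranges and values. The sketch below shows the arithmetic they describe; it is not the code used by `run_quant_qa.py`, and the function names and numbers are hypothetical.

```python
def fused_qkv_amax(amax_q, amax_k, amax_v):
    """Single shared range for Q, K and V: the maximum of the three ranges."""
    return max(amax_q, amax_k, amax_v)


def clip_gelu_output(values, clip_max):
    """Cap GELU outputs at `clip_max` so large activations do not widen the range."""
    return [min(v, clip_max) for v in values]


shared = fused_qkv_amax(1.8, 2.4, 0.9)
clipped = clip_gelu_output([0.3, 4.0, 27.5], clip_max=10.0)
print(shared)    # 2.4
print(clipped)   # [0.3, 4.0, 10.0]
```

Sharing one range across Q, K and V trades some precision on the smaller tensors for a layout that is easier to fuse, and clipping GELU keeps occasional large outputs from inflating the activation range.
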
456
examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py
Executable file
@@ -0,0 +1,456 @@
# coding=utf-8
# Copyright 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Evaluating an ONNX question-answering model on SQuAD with TensorRT."""
import argparse
import logging
import os
import time
import timeit

import datasets
import numpy as np
import torch
from absl import logging as absl_logging
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader

import pycuda.autoinit  # noqa: F401
import pycuda.driver as cuda
import tensorrt as trt
import transformers
from accelerate import Accelerator
from transformers import AutoTokenizer, EvalPrediction, default_data_collator, set_seed
from transformers.trainer_pt_utils import nested_concat, nested_truncate
from utils_qa import postprocess_qa_predictions


TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
absl_logger = absl_logging.get_absl_logger()
absl_logger.setLevel(logging.WARNING)

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()

# Required parameters
parser.add_argument(
    "--onnx_model_path",
    default=None,
    type=str,
    required=True,
    help="Path to ONNX model: ",
)

parser.add_argument(
    "--output_dir",
    default=None,
    type=str,
    required=True,
    help="The output directory where the model checkpoints and predictions will be written.",
)

# Other parameters

parser.add_argument(
    "--tokenizer_name",
    default="",
    type=str,
    required=True,
    help="Pretrained tokenizer name or path if not the same as model_name",
)

parser.add_argument(
    "--version_2_with_negative",
    action="store_true",
    help="If true, the SQuAD examples contain some that do not have an answer.",
)
parser.add_argument(
    "--null_score_diff_threshold",
    type=float,
    default=0.0,
    help="If null_score - best_non_null is greater than the threshold predict null.",
)

parser.add_argument(
    "--max_seq_length",
    default=384,
    type=int,
    help="The maximum total input sequence length after WordPiece tokenization. Sequences "
    "longer than this will be truncated, and sequences shorter than this will be padded.",
)
parser.add_argument(
    "--doc_stride",
    default=128,
    type=int,
    help="When splitting up a long document into chunks, how much stride to take between chunks.",
)

parser.add_argument("--per_device_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.")

parser.add_argument(
    "--n_best_size",
    default=20,
    type=int,
    help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
)
parser.add_argument(
    "--max_answer_length",
    default=30,
    type=int,
    help="The maximum length of an answer that can be generated. This is needed because the start "
    "and end predictions are not conditioned on one another.",
)

parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

parser.add_argument(
    "--dataset_name",
    type=str,
    default=None,
    required=True,
    help="The name of the dataset to use (via the datasets library).",
)
parser.add_argument(
    "--dataset_config_name",
    type=str,
    default=None,
    help="The configuration name of the dataset to use (via the datasets library).",
)
parser.add_argument(
    "--preprocessing_num_workers", type=int, default=4, help="The number of processes to use for the preprocessing."
)
# `type=bool` would treat any non-empty string (including "False") as True, so use a flag instead.
parser.add_argument(
    "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
)
parser.add_argument(
    "--fp16",
    action="store_true",
    help="Whether to use 16-bit (mixed) precision instead of 32-bit",
)
parser.add_argument(
    "--int8",
    action="store_true",
    help="Whether to use INT8",
)

args = parser.parse_args()

if args.tokenizer_name:
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
else:
    raise ValueError(
        "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
        "You can do it from another script, save it, and load it from here, using --tokenizer_name."
    )

logger.info("Training/evaluation parameters %s", args)

args.eval_batch_size = args.per_device_eval_batch_size

INPUT_SHAPE = (args.eval_batch_size, args.max_seq_length)

# TRT Engine properties
STRICT_TYPES = True

engine_name = "temp_engine/bert-fp32.engine"
if args.fp16:
    engine_name = "temp_engine/bert-fp16.engine"
if args.int8:
    engine_name = "temp_engine/bert-int8.engine"

# import ONNX file
if not os.path.exists("temp_engine"):
    os.makedirs("temp_engine")

EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(
    network, TRT_LOGGER
) as parser:
    with open(args.onnx_model_path, "rb") as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))

    # Query input names and shapes from parsed TensorRT network
    network_inputs = [network.get_input(i) for i in range(network.num_inputs)]
    input_names = [_input.name for _input in network_inputs]  # ex: ["actual_input1"]

    with builder.create_builder_config() as config:
        config.max_workspace_size = 1 << 50
        if STRICT_TYPES:
            config.set_flag(trt.BuilderFlag.STRICT_TYPES)
        if args.fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        if args.int8:
            config.set_flag(trt.BuilderFlag.INT8)
        profile = builder.create_optimization_profile()
        config.add_optimization_profile(profile)
        for i in range(len(input_names)):
            profile.set_shape(input_names[i], INPUT_SHAPE, INPUT_SHAPE, INPUT_SHAPE)
        engine = builder.build_engine(network, config)

        # serialize_engine and store in file (can be directly loaded and deserialized):
        with open(engine_name, "wb") as f:
            f.write(engine.serialize())


# run inference with TRT
def model_infer(inputs, context, d_inputs, h_output0, h_output1, d_output0, d_output1, stream):
    input_ids = np.asarray(inputs["input_ids"], dtype=np.int32)
    attention_mask = np.asarray(inputs["attention_mask"], dtype=np.int32)
    token_type_ids = np.asarray(inputs["token_type_ids"], dtype=np.int32)

    # Copy inputs
    cuda.memcpy_htod_async(d_inputs[0], input_ids.ravel(), stream)
    cuda.memcpy_htod_async(d_inputs[1], attention_mask.ravel(), stream)
    cuda.memcpy_htod_async(d_inputs[2], token_type_ids.ravel(), stream)
    # start time
    start_time = time.time()
    # Run inference
    context.execute_async(
        bindings=[int(d_inp) for d_inp in d_inputs] + [int(d_output0), int(d_output1)], stream_handle=stream.handle
    )
    # Transfer predictions back from GPU
    cuda.memcpy_dtoh_async(h_output0, d_output0, stream)
    cuda.memcpy_dtoh_async(h_output1, d_output1, stream)
    # Synchronize the stream and take time
    stream.synchronize()
    # end time
    end_time = time.time()
    infer_time = end_time - start_time
    outputs = (h_output0, h_output1)
    # print(outputs)
    return outputs, infer_time


# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

# Setup logging, we only want one process per machine to log things on the screen.
# accelerator.is_local_main_process is only True for one process per machine.
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
if accelerator.is_local_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)

# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
if args.dataset_name is not None:
    # Downloading and loading a dataset from the hub.
    raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
else:
    raise ValueError("Evaluation requires a dataset name")
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.

# Preprocessing the datasets.
# Preprocessing is slightly different for training and evaluation.

column_names = raw_datasets["validation"].column_names

question_column_name = "question" if "question" in column_names else column_names[0]
context_column_name = "context" if "context" in column_names else column_names[1]
answer_column_name = "answers" if "answers" in column_names else column_names[2]

# Padding side determines if we do (question|context) or (context|question).
pad_on_right = tokenizer.padding_side == "right"

if args.max_seq_length > tokenizer.model_max_length:
    logger.warning(
        f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the "
        f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
    )

max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)


# Validation preprocessing
def prepare_validation_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
    # left whitespace
    examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
    # in one example possibly giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples[question_column_name if pad_on_right else context_column_name],
        examples[context_column_name if pad_on_right else question_column_name],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_seq_length,
        stride=args.doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
    # corresponding example_id and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
        # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples


eval_examples = raw_datasets["validation"]
# Validation Feature Creation
eval_dataset = eval_examples.map(
    prepare_validation_features,
    batched=True,
    num_proc=args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not args.overwrite_cache,
    desc="Running tokenizer on validation dataset",
)

data_collator = default_data_collator

eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
eval_dataloader = DataLoader(
    eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
)


# Post-processing:
def post_processing_function(examples, features, predictions, stage="eval"):
    # Post-processing: we match the start logits and end logits to answers in the original context.
    predictions = postprocess_qa_predictions(
        examples=examples,
        features=features,
        predictions=predictions,
        version_2_with_negative=args.version_2_with_negative,
        n_best_size=args.n_best_size,
        max_answer_length=args.max_answer_length,
        null_score_diff_threshold=args.null_score_diff_threshold,
        output_dir=args.output_dir,
        prefix=stage,
    )
    # Format the result to the format the metric expects.
    if args.version_2_with_negative:
        formatted_predictions = [
            {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
        ]
    else:
        formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]

    references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
    return EvalPrediction(predictions=formatted_predictions, label_ids=references)


metric = load_metric("squad_v2" if args.version_2_with_negative else "squad")

# Evaluation!
logger.info("Loading ONNX model %s for evaluation", args.onnx_model_path)
with open(engine_name, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.deserialize_cuda_engine(
    f.read()
) as engine, engine.create_execution_context() as context:

    # setup for TRT inference
    for i in range(len(input_names)):
        context.set_binding_shape(i, INPUT_SHAPE)
    assert context.all_binding_shapes_specified

    def binding_nbytes(binding):
        return trt.volume(engine.get_binding_shape(binding)) * engine.get_binding_dtype(binding).itemsize

    # Allocate device memory for inputs and outputs.
    d_inputs = [cuda.mem_alloc(binding_nbytes(binding)) for binding in engine if engine.binding_is_input(binding)]

    # Allocate output buffer
    h_output0 = cuda.pagelocked_empty(tuple(context.get_binding_shape(3)), dtype=np.float32)
    h_output1 = cuda.pagelocked_empty(tuple(context.get_binding_shape(4)), dtype=np.float32)
    d_output0 = cuda.mem_alloc(h_output0.nbytes)
    d_output1 = cuda.mem_alloc(h_output1.nbytes)

    # Create a stream in which to copy inputs/outputs and run inference.
    stream = cuda.Stream()

    # Evaluation
    logger.info("***** Running Evaluation *****")
    logger.info(f" Num examples = {len(eval_dataset)}")
    logger.info(f" Batch size = {args.per_device_eval_batch_size}")

    total_time = 0.0
    niter = 0
    start_time = timeit.default_timer()

    all_preds = None
    for step, batch in enumerate(eval_dataloader):

        outputs, infer_time = model_infer(batch, context, d_inputs, h_output0, h_output1, d_output0, d_output1, stream)
        total_time += infer_time
        niter += 1

        start_logits, end_logits = outputs
        start_logits = torch.tensor(start_logits)
        end_logits = torch.tensor(end_logits)

        # necessary to pad predictions and labels for being gathered
        start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
        end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)

        logits = (accelerator.gather(start_logits).cpu().numpy(), accelerator.gather(end_logits).cpu().numpy())
        all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)

    if all_preds is not None:
        all_preds = nested_truncate(all_preds, len(eval_dataset))

    evalTime = timeit.default_timer() - start_time
    logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(eval_dataset))
    # Inference time from TRT
    logger.info("Average Inference Time = {:.3f} ms".format(total_time * 1000 / niter))
    logger.info("Total Inference Time = {:.3f} ms".format(total_time * 1000))
    logger.info("Total Number of Inference = %d", niter)

prediction = post_processing_function(eval_examples, eval_dataset, all_preds)
eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
logger.info(f"Evaluation metrics: {eval_metric}")
303
examples/research_projects/quantization-qdqbert/quant_trainer.py
Executable file
@@ -0,0 +1,303 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 NVIDIA Corporation. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Helper functions for training models with pytorch-quantization"""
|
||||
import logging
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
import pytorch_quantization
|
||||
import pytorch_quantization.nn as quant_nn
|
||||
from pytorch_quantization import calib
|
||||
from pytorch_quantization.tensor_quant import QuantDescriptor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
name_width = 50 # max width of layer names
|
||||
qname_width = 70 # max width of quantizer names
|
||||
|
||||
# ========================================== Quant Trainer API ==========================================
|
||||
|
||||
|
||||
def add_arguments(parser):
|
||||
"""Add arguments to parser for functions defined in quant_trainer."""
|
||||
|
||||
group = parser.add_argument_group("quant_trainer arguments")
|
||||
group.add_argument("--wprec", type=int, default=8, help="weight precision")
|
||||
group.add_argument("--aprec", type=int, default=8, help="activation precision")
|
||||
group.add_argument("--quant-per-tensor", action="store_true", help="per tensor weight scaling")
|
||||
group.add_argument("--quant-disable", action="store_true", help="disable all quantizers")
|
||||
group.add_argument("--quant-disable-embeddings", action="store_true", help="disable all embeddings quantizers")
|
||||
group.add_argument("--quant-disable-keyword", type=str, nargs="+", help="disable quantizers by keyword")
|
||||
group.add_argument("--quant-disable-layer-module", type=str, help=r"disable quantizers by keyword under layer.\d+.")
|
||||
group.add_argument("--quant-enable-layer-module", type=str, help=r"enable quantizers by keyword under layer.\d+.")
|
||||
group.add_argument("--calibrator", default="max", help="which quantization range calibrator to use")
|
||||
group.add_argument("--percentile", default=None, type=float, help="percentile for PercentileCalibrator")
|
||||
group.add_argument("--fuse-qkv", action="store_true", help="use the same scale factor for qkv")
|
||||
group.add_argument("--clip-gelu", metavar="N", type=float, help="clip gelu output maximum value to N")
|
||||
group.add_argument(
|
||||
"--recalibrate-weights",
|
||||
action="store_true",
|
||||
help="recalibrate weight amaxes by taking the max of the weights."
|
||||
" amaxes will be computed with the current quantization granularity (axis).",
|
||||
)
|
||||
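A minimal sketch of how these flags surface on the parsed namespace, assuming plain `argparse` rather than the `HfArgumentParser` used by the training script. Only three of the flags above are reproduced, and the keyword values passed on the command line are illustrative. It shows that hyphenated option names become underscore attributes, which is how `set_default_quantizers` and `configure_model` read them.

```python
import argparse

# Rebuild a small subset of the quant_trainer argument group.
parser = argparse.ArgumentParser()
group = parser.add_argument_group("quant_trainer arguments")
group.add_argument("--wprec", type=int, default=8, help="weight precision")
group.add_argument("--quant-per-tensor", action="store_true", help="per tensor weight scaling")
group.add_argument("--quant-disable-keyword", type=str, nargs="+", help="disable quantizers by keyword")

# Illustrative command line; the keyword values are placeholders.
args = parser.parse_args(["--quant-per-tensor", "--quant-disable-keyword", "layernorm", "softmax"])

# argparse converts "--quant-per-tensor" into the attribute "quant_per_tensor".
# Same expression as in set_default_quantizers: None means one scale for the whole tensor.
weight_axis = None if args.quant_per_tensor else (0,)
```

Without `--quant-per-tensor`, `weight_axis` would be `(0,)`, i.e. one scale per output channel.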
|
||||
|
||||
def set_default_quantizers(args):
|
||||
"""Set default quantizers before creating the model."""
|
||||
|
||||
if args.calibrator == "max":
|
||||
calib_method = "max"
|
||||
elif args.calibrator == "percentile":
|
||||
if args.percentile is None:
|
||||
raise ValueError("Specify --percentile when using percentile calibrator")
|
||||
calib_method = "histogram"
|
||||
elif args.calibrator == "mse":
|
||||
calib_method = "histogram"
|
||||
else:
|
||||
raise ValueError(f"Invalid calibrator {args.calibrator}")
|
||||
|
||||
input_desc = QuantDescriptor(num_bits=args.aprec, calib_method=calib_method)
|
||||
weight_desc = QuantDescriptor(num_bits=args.wprec, axis=(None if args.quant_per_tensor else (0,)))
|
||||
quant_nn.QuantLinear.set_default_quant_desc_input(input_desc)
|
||||
quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc)
|
||||
|
||||
|
||||
def configure_model(model, args, calib=False, eval=False):
|
||||
"""Function called before the training loop."""
|
||||
|
||||
logger.info("Configuring Model for Quantization")
|
||||
logger.info(f"using quantization package {pytorch_quantization.__file__}")
|
||||
|
||||
if not calib:
|
||||
if args.quant_disable_embeddings:
|
||||
set_quantizer_by_name(model, ["embeddings"], which="weight", _disabled=True)
|
||||
|
||||
if args.quant_disable:
|
||||
set_quantizer_by_name(model, [""], _disabled=True)
|
||||
|
||||
if args.quant_disable_keyword:
|
||||
set_quantizer_by_name(model, args.quant_disable_keyword, _disabled=True)
|
||||
|
||||
if args.quant_disable_layer_module:
|
||||
set_quantizer_by_name(model, [r"layer.\d+." + args.quant_disable_layer_module], _disabled=True)
|
||||
|
||||
if args.quant_enable_layer_module:
|
||||
set_quantizer_by_name(model, [r"layer.\d+." + args.quant_enable_layer_module], _disabled=False)
|
||||
|
||||
if args.recalibrate_weights:
|
||||
recalibrate_weights(model)
|
||||
|
||||
if args.fuse_qkv:
|
||||
fuse_qkv(model, args)
|
||||
|
||||
if args.clip_gelu:
|
||||
clip_gelu(model, args.clip_gelu)
|
||||
|
||||
# if args.local_rank in [-1, 0] and not calib:
|
||||
print_quant_summary(model)
|
||||
|
||||
|
||||
def enable_calibration(model):
|
||||
"""Enable calibration of all *_input_quantizer modules in model."""
|
||||
|
||||
logger.info("Enabling Calibration")
|
||||
for name, module in model.named_modules():
|
||||
if name.endswith("_quantizer"):
|
||||
if module._calibrator is not None:
|
||||
module.disable_quant()
|
||||
module.enable_calib()
|
||||
else:
|
||||
module.disable()
|
||||
logger.info(f"{name:80}: {module}")
|
||||
|
||||
|
||||
def finish_calibration(model, args):
|
||||
"""Disable calibration and load amax for all *_input_quantizer modules in model."""
|
||||
|
||||
logger.info("Loading calibrated amax")
|
||||
for name, module in model.named_modules():
|
||||
if name.endswith("_quantizer"):
|
||||
if module._calibrator is not None:
|
||||
if isinstance(module._calibrator, calib.MaxCalibrator):
|
||||
module.load_calib_amax()
|
||||
else:
|
||||
module.load_calib_amax("percentile", percentile=args.percentile)
|
||||
module.enable_quant()
|
||||
module.disable_calib()
|
||||
else:
|
||||
module.enable()
|
||||
model.cuda()
|
||||
print_quant_summary(model)
|
||||
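The calibration pair above follows a collect-then-load pattern: while calibration is enabled the quantizers only observe activations, and afterwards the observed range is loaded as `amax`. The toy class below is a pure-Python stand-in for that idea under max calibration. `ToyMaxCalibrator` and its methods are hypothetical names for illustration, not the pytorch-quantization API.

```python
class ToyMaxCalibrator:
    """Hypothetical stand-in: tracks the running maximum absolute value seen so far."""

    def __init__(self):
        self.amax = None

    def collect(self, batch):
        # Observe one batch of activations (a flat list of floats here).
        batch_max = max(abs(x) for x in batch)
        self.amax = batch_max if self.amax is None else max(self.amax, batch_max)

    def compute_amax(self):
        if self.amax is None:
            raise RuntimeError("collect() must be called at least once")
        return self.amax


calibrator = ToyMaxCalibrator()
for batch in ([0.5, -1.5], [2.0, -0.25]):  # illustrative calibration batches
    calibrator.collect(batch)
calibrated_amax = calibrator.compute_amax()
```

Histogram-based calibrators (percentile, mse) would replace the running maximum with a histogram of observed values; that part is not sketched here.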
|
||||
|
||||
# ========================================== Helper Function ==========================================
|
||||
|
||||
|
||||
def fuse_qkv(model, args):
|
||||
"""Adjust quantization ranges to match an implementation where the QKV projections are implemented with a single GEMM.
|
||||
Force the weight and output scale factors to match by taking the max of (Q,K,V).
|
||||
"""
|
||||
|
||||
def fuse3(qq, qk, qv):
|
||||
for mod in [qq, qk, qv]:
|
||||
if not hasattr(mod, "_amax"):
|
||||
print(" WARNING: NO AMAX BUFFER")
|
||||
return
|
||||
q = qq._amax.detach().item()
|
||||
k = qk._amax.detach().item()
|
||||
v = qv._amax.detach().item()
|
||||
|
||||
amax = max(q, k, v)
|
||||
qq._amax.fill_(amax)
|
||||
qk._amax.fill_(amax)
|
||||
qv._amax.fill_(amax)
|
||||
logger.info(f" q={q:5.2f} k={k:5.2f} v={v:5.2f} -> {amax:5.2f}")
|
||||
|
||||
for name, mod in model.named_modules():
|
||||
if name.endswith(".attention.self"):
|
||||
logger.info(f"FUSE_QKV: {name:{name_width}}")
|
||||
fuse3(mod.matmul_q_input_quantizer, mod.matmul_k_input_quantizer, mod.matmul_v_input_quantizer)
|
||||
if args.quant_per_tensor:
|
||||
fuse3(mod.query._weight_quantizer, mod.key._weight_quantizer, mod.value._weight_quantizer)
|
||||
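A small sketch of what `fuse3` does numerically, using plain floats instead of `_amax` buffers. The helper names are hypothetical, and the scale formula assumes symmetric signed quantization in which `amax` maps to the largest positive integer value; treat that as an assumption rather than a statement about the toolkit's internals.

```python
def fuse_amax(q_amax, k_amax, v_amax):
    """Return the shared amax that all three quantizers would be filled with."""
    return max(q_amax, k_amax, v_amax)


def symmetric_scale(amax, num_bits=8):
    """Assumed symmetric scale: amax maps to 2**(num_bits - 1) - 1 (127 for 8 bits)."""
    return amax / (2 ** (num_bits - 1) - 1)


# Illustrative per-tensor ranges for the Q, K and V projections.
shared_amax = fuse_amax(2.54, 3.81, 1.27)
shared_scale = symmetric_scale(shared_amax)
```

Taking the maximum keeps every value representable for all three projections, at the cost of coarser resolution for the projections whose own range was smaller.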
|
||||
|
||||
def clip_gelu(model, maxval):
|
||||
"""Clip activations generated by GELU to maxval when quantized.
|
||||
Implemented by adjusting the amax of the following input_quantizer.
|
||||
"""
|
||||
|
||||
for name, mod in model.named_modules():
|
||||
if name.endswith(".output.dense") and not name.endswith("attention.output.dense"):
|
||||
amax_init = mod._input_quantizer._amax.data.detach().item()
|
||||
mod._input_quantizer._amax.data.detach().clamp_(max=maxval)
|
||||
amax = mod._input_quantizer._amax.data.detach().item()
|
||||
logger.info(f"CLIP_GELU: {name:{name_width}} amax: {amax_init:5.2f} -> {amax:5.2f}")
|
||||
|
||||
|
||||
def expand_amax(model):
|
||||
"""Expand per-tensor amax to be per channel, where each channel is assigned the per-tensor amax."""
|
||||
|
||||
for name, mod in model.named_modules():
|
||||
if hasattr(mod, "_weight_quantizer") and mod._weight_quantizer.axis is not None:
|
||||
k = mod.weight.shape[0]
|
||||
amax = mod._weight_quantizer._amax.detach()
|
||||
mod._weight_quantizer._amax = torch.ones(k, dtype=amax.dtype, device=amax.device) * amax
|
||||
print(f"expanding {name} {amax} -> {mod._weight_quantizer._amax}")
|
||||
|
||||
|
||||
def recalibrate_weights(model):
|
||||
"""Performs max calibration on the weights and updates amax."""
|
||||
|
||||
for name, mod in model.named_modules():
|
||||
if hasattr(mod, "_weight_quantizer"):
|
||||
if not hasattr(mod.weight_quantizer, "_amax"):
|
||||
print(f"RECALIB: {name:{name_width}} WARNING: NO AMAX BUFFER")
|
||||
continue
|
||||
|
||||
# determine which axes to reduce across
|
||||
# e.g. a 4D tensor quantized per axis 0 should reduce over (1,2,3)
|
||||
axis_set = set() if mod._weight_quantizer.axis is None else set(mod._weight_quantizer.axis)
|
||||
reduce_axis = set(range(len(mod.weight.size()))) - axis_set
|
||||
amax = pytorch_quantization.utils.reduce_amax(mod.weight, axis=reduce_axis, keepdims=True).detach()
|
||||
logger.info(f"RECALIB: {name:{name_width}} {mod._weight_quantizer._amax.flatten()} -> {amax.flatten()}")
|
||||
mod._weight_quantizer._amax = amax
|
||||
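The axis bookkeeping in `recalibrate_weights` can be checked without any tensors. The sketch below reproduces the set arithmetic and, for a 2D weight given as nested lists, the per-row maximum that per-channel quantization along axis 0 would produce. Both helper names are hypothetical.

```python
def reduce_axes(ndim, quant_axis):
    """Axes to reduce over: every axis that is not a quantization axis."""
    axis_set = set() if quant_axis is None else set(quant_axis)
    return sorted(set(range(ndim)) - axis_set)


def per_row_amax(weight):
    """Max absolute value of each row of a 2D weight (quantization axis 0)."""
    return [max(abs(x) for x in row) for row in weight]


axes_4d = reduce_axes(4, (0,))          # 4D tensor quantized per axis 0
axes_per_tensor = reduce_axes(2, None)  # per-tensor: reduce over everything
row_amax = per_row_amax([[0.5, -2.0], [1.5, 0.25]])
```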
|
||||
|
||||
def print_model_summary(model, name_width=25, line_width=180, ignore=None):
|
||||
"""Print model quantization configuration."""
|
||||
|
||||
if ignore is None:
|
||||
ignore = []
|
||||
elif not isinstance(ignore, list):
|
||||
ignore = [ignore]
|
||||
|
||||
name_width = 0
|
||||
for name, mod in model.named_modules():
|
||||
if not hasattr(mod, "weight"):
|
||||
continue
|
||||
name_width = max(name_width, len(name))
|
||||
|
||||
for name, mod in model.named_modules():
|
||||
input_q = getattr(mod, "_input_quantizer", None)
|
||||
weight_q = getattr(mod, "_weight_quantizer", None)
|
||||
if not hasattr(mod, "weight"):
|
||||
continue
|
||||
if type(mod) in ignore:
|
||||
continue
|
||||
if [True for s in ignore if type(s) is str and s in name]:
|
||||
continue
|
||||
act_str = f"Act:{input_q.extra_repr()}"
|
||||
wgt_str = f"Wgt:{weight_q.extra_repr()}"
|
||||
s = f"{name:{name_width}} {act_str} {wgt_str}"
|
||||
if len(s) <= line_width:
|
||||
logger.info(s)
|
||||
else:
|
||||
logger.info(f"{name:{name_width}} {act_str}")
|
||||
logger.info(f'{" ":{name_width}} {wgt_str}')
|
||||
|
||||
|
||||
def print_quant_summary(model):
|
||||
"""Print summary of all quantizer modules in the model."""
|
||||
|
||||
count = 0
|
||||
for name, mod in model.named_modules():
|
||||
if isinstance(mod, pytorch_quantization.nn.TensorQuantizer):
|
||||
print(f"{name:80} {mod}")
|
||||
count += 1
|
||||
print(f"{count} TensorQuantizers found in model")
|
||||
|
||||
|
||||
def set_quantizer(name, mod, quantizer, k, v):
|
||||
"""Set attributes for mod.quantizer."""
|
||||
|
||||
quantizer_mod = getattr(mod, quantizer, None)
|
||||
if quantizer_mod is not None:
|
||||
assert hasattr(quantizer_mod, k)
|
||||
setattr(quantizer_mod, k, v)
|
||||
else:
|
||||
logger.warning(f"{name} has no {quantizer}")
|
||||
|
||||
|
||||
def set_quantizers(name, mod, which="both", **kwargs):
|
||||
"""Set quantizer attributes for mod."""
|
||||
|
||||
s = f"Warning: changing {which} quantizers of {name:{qname_width}}"
|
||||
for k, v in kwargs.items():
|
||||
s += f" {k}={v}"
|
||||
if which in ["input", "both"]:
|
||||
set_quantizer(name, mod, "_input_quantizer", k, v)
|
||||
if which in ["weight", "both"]:
|
||||
set_quantizer(name, mod, "_weight_quantizer", k, v)
|
||||
logger.info(s)
|
||||
|
||||
|
||||
def set_quantizer_by_name(model, names, **kwargs):
|
||||
"""Set quantizer attributes for layers where name contains a substring in names."""
|
||||
|
||||
for name, mod in model.named_modules():
|
||||
if hasattr(mod, "_input_quantizer") or hasattr(mod, "_weight_quantizer"):
|
||||
for n in names:
|
||||
if re.search(n, name):
|
||||
set_quantizers(name, mod, **kwargs)
|
||||
elif name.endswith("_quantizer"):
|
||||
for n in names:
|
||||
if re.search(n, name):
|
||||
s = f"Warning: changing {name:{name_width}}"
|
||||
for k, v in kwargs.items():
|
||||
s += f" {k}={v}"
|
||||
setattr(mod, k, v)
|
||||
logger.info(s)
|
||||
668
examples/research_projects/quantization-qdqbert/run_quant_qa.py
Executable file
@@ -0,0 +1,668 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
# Copyright 2021 NVIDIA Corporation. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for question answering.
|
||||
"""
|
||||
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
from datasets import load_dataset, load_metric
|
||||
|
||||
import quant_trainer
|
||||
import transformers
|
||||
from trainer_quant_qa import QuestionAnsweringTrainer
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerFast,
|
||||
QDQBertConfig,
|
||||
QDQBertForQuestionAnswering,
|
||||
TrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import SchedulerType, get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.9.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
do_calib: bool = field(default=False, metadata={"help": "Whether to run calibration of quantization ranges."})
|
||||
num_calib_batch: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Number of batches for calibration. 0 will disable calibration "},
|
||||
)
|
||||
save_onnx: bool = field(default=False, metadata={"help": "Whether to save model to onnx."})
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a csv or json file)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate on (a csv or json file)."},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input test data file to run predictions on (a csv or json file)."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=384,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
|
||||
"be faster on GPU but will be slower on TPU)."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
version_2_with_negative: bool = field(
|
||||
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
|
||||
)
|
||||
null_score_diff_threshold: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "The threshold used to select the null answer: if the best answer has a score that is less than "
|
||||
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
||||
"Only useful when `version_2_with_negative=True`."
|
||||
},
|
||||
)
|
||||
doc_stride: int = field(
|
||||
default=128,
|
||||
metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
|
||||
)
|
||||
n_best_size: int = field(
|
||||
default=20,
|
||||
metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
|
||||
)
|
||||
max_answer_length: int = field(
|
||||
default=30,
|
||||
metadata={
|
||||
"help": "The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if (
|
||||
self.dataset_name is None
|
||||
and self.train_file is None
|
||||
and self.validation_file is None
|
||||
and self.test_file is None
|
||||
):
|
||||
raise ValueError("Need either a dataset name or a training/validation file/test_file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if self.test_file is not None:
|
||||
extension = self.test_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
# quant_trainer arguments
|
||||
quant_trainer.add_arguments(parser)
|
||||
|
||||
# if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# # If we pass only one argument to the script and it's the path to a json file,
|
||||
# # let's parse it to get our arguments.
|
||||
# model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
# else:
|
||||
|
||||
model_args, data_args, training_args, quant_trainer_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# setup QAT training args for scheduler (default to use cosine annealing learning rate schedule)
|
||||
training_args.lr_scheduler_type = SchedulerType.COSINE
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||
# 'text' is found. You can easily tweak this behavior (see below).
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
raw_datasets = load_dataset(
|
||||
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.validation_file.split(".")[-1]
|
||||
if data_args.test_file is not None:
|
||||
data_files["test"] = data_args.test_file
|
||||
extension = data_args.test_file.split(".")[-1]
|
||||
raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
# set default quantization parameters before building model
|
||||
quant_trainer.set_default_quantizers(quant_trainer_args)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
config = QDQBertConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=True,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
model = QDQBertForQuestionAnswering.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Tokenizer check: this script requires a fast tokenizer.
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
raise ValueError(
|
||||
"This example script only works for models that have a fast tokenizer. Check out the big table of models "
|
||||
"at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
|
||||
"requirement"
|
||||
)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# Preprocessing is slightly different for training and evaluation.
|
||||
if training_args.do_train or model_args.do_calib:
|
||||
column_names = raw_datasets["train"].column_names
|
||||
elif training_args.do_eval or model_args.save_onnx:
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
else:
|
||||
column_names = raw_datasets["test"].column_names
|
||||
question_column_name = "question" if "question" in column_names else column_names[0]
|
||||
context_column_name = "context" if "context" in column_names else column_names[1]
|
||||
answer_column_name = "answers" if "answers" in column_names else column_names[2]
|
||||
|
||||
# Padding side determines if we do (question|context) or (context|question).
|
||||
pad_on_right = tokenizer.padding_side == "right"
|
||||
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
# Training preprocessing
|
||||
def prepare_train_features(examples):
|
||||
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
||||
# in one example possibly giving several features when a context is long, each of those features having a
|
||||
# context that overlaps a bit the context of the previous feature.
|
||||
tokenized_examples = tokenizer(
|
||||
examples[question_column_name if pad_on_right else context_column_name],
|
||||
examples[context_column_name if pad_on_right else question_column_name],
|
||||
truncation="only_second" if pad_on_right else "only_first",
|
||||
max_length=max_seq_length,
|
||||
stride=data_args.doc_stride,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
padding="max_length" if data_args.pad_to_max_length else False,
|
||||
)
|
||||
|
||||
# Since one example might give us several features if it has a long context, we need a map from a feature to
|
||||
# its corresponding example. This key gives us just that.
|
||||
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
||||
# The offset mappings will give us a map from token to character position in the original context. This will
|
||||
# help us compute the start_positions and end_positions.
|
||||
offset_mapping = tokenized_examples.pop("offset_mapping")
|
||||
|
||||
# Let's label those examples!
|
||||
tokenized_examples["start_positions"] = []
|
||||
tokenized_examples["end_positions"] = []
|
||||
|
||||
for i, offsets in enumerate(offset_mapping):
|
||||
# We will label impossible answers with the index of the CLS token.
|
||||
input_ids = tokenized_examples["input_ids"][i]
|
||||
cls_index = input_ids.index(tokenizer.cls_token_id)
|
||||
|
||||
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
|
||||
sequence_ids = tokenized_examples.sequence_ids(i)
|
||||
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = sample_mapping[i]
|
||||
answers = examples[answer_column_name][sample_index]
|
||||
# If no answers are given, set the cls_index as answer.
|
||||
if len(answers["answer_start"]) == 0:
|
||||
tokenized_examples["start_positions"].append(cls_index)
|
||||
tokenized_examples["end_positions"].append(cls_index)
|
||||
else:
|
||||
# Start/end character index of the answer in the text.
|
||||
start_char = answers["answer_start"][0]
|
||||
end_char = start_char + len(answers["text"][0])
|
||||
|
||||
# Start token index of the current span in the text.
|
||||
token_start_index = 0
|
||||
while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
|
||||
token_start_index += 1
|
||||
|
||||
# End token index of the current span in the text.
|
||||
token_end_index = len(input_ids) - 1
|
||||
while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
|
||||
token_end_index -= 1
|
||||
|
||||
# Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
|
||||
if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
|
||||
tokenized_examples["start_positions"].append(cls_index)
|
||||
tokenized_examples["end_positions"].append(cls_index)
|
||||
else:
|
||||
# Otherwise move the token_start_index and token_end_index to the two ends of the answer.
|
||||
# Note: we could go after the last offset if the answer is the last word (edge case).
|
||||
while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
|
||||
token_start_index += 1
|
||||
tokenized_examples["start_positions"].append(token_start_index - 1)
|
||||
while offsets[token_end_index][1] >= end_char:
|
||||
token_end_index -= 1
|
||||
tokenized_examples["end_positions"].append(token_end_index + 1)
|
||||
|
||||
return tokenized_examples
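The labeling loop in `prepare_train_features` can be read as a mapping from a character span to a token span. Below is a minimal sketch of that idea, not the script's own code: `char_span_to_token_span`, `first_ctx`, and `last_ctx` are hypothetical names, and the toy offsets assume that special tokens carry a `(0, 0)` offset. Unlike the loop above, the sketch bounds both scans by the context window so that tokens outside the context are never visited.

```python
def char_span_to_token_span(offsets, first_ctx, last_ctx, start_char, end_char, cls_index=0):
    """Map a character span of the answer to (start_token, end_token).

    offsets            -- list of (char_start, char_end) per token
    first_ctx/last_ctx -- first and last token index of the context (hypothetical names)
    """
    # Answer not fully inside this feature's context window: label with CLS.
    if not (offsets[first_ctx][0] <= start_char and offsets[last_ctx][1] >= end_char):
        return cls_index, cls_index

    # Walk right until a token starts after the answer start, then step back one.
    start = first_ctx
    while start <= last_ctx and offsets[start][0] <= start_char:
        start += 1

    # Walk left until a token ends before the answer end, then step forward one.
    end = last_ctx
    while end >= first_ctx and offsets[end][1] >= end_char:
        end -= 1

    return start - 1, end + 1


# Toy feature: [CLS] who ? [SEP] the cat sat [SEP]
toy_offsets = [(0, 0), (0, 3), (3, 4), (0, 0), (0, 3), (4, 7), (8, 11), (0, 0)]
cat_span = char_span_to_token_span(toy_offsets, 4, 6, start_char=4, end_char=7)
sat_span = char_span_to_token_span(toy_offsets, 4, 6, start_char=8, end_char=11)
missing = char_span_to_token_span(toy_offsets, 4, 6, start_char=20, end_char=25)
```

In this toy input, "cat" and "sat" each map to a single token, and an answer that lies outside the window falls back to the CLS index.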
|
||||
|
||||
if training_args.do_train or model_args.do_calib:
|
||||
if "train" not in raw_datasets:
|
||||
raise ValueError("--do_train and --do_calib require a train dataset")
|
||||
train_dataset = raw_datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
# We will select samples from the whole dataset if the argument is specified
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
# Create train feature from dataset
|
||||
with training_args.main_process_first(desc="train dataset map pre-processing"):
|
||||
train_dataset = train_dataset.map(
|
||||
prepare_train_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on train dataset",
|
||||
)
|
||||
if data_args.max_train_samples is not None:
|
||||
# The number of samples might increase during feature creation, so we select only the specified max samples
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
|
||||
# Validation preprocessing
|
||||
def prepare_validation_features(examples):
|
||||
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
||||
# in one example possibly giving several features when a context is long, each of those features having a
|
||||
# context that overlaps a bit the context of the previous feature.
|
||||
tokenized_examples = tokenizer(
|
||||
examples[question_column_name if pad_on_right else context_column_name],
|
||||
examples[context_column_name if pad_on_right else question_column_name],
|
||||
truncation="only_second" if pad_on_right else "only_first",
|
||||
max_length=max_seq_length,
|
||||
stride=data_args.doc_stride,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
padding="max_length" if data_args.pad_to_max_length else False,
|
||||
)
|
||||
|
||||
# Since one example might give us several features if it has a long context, we need a map from a feature to
|
||||
# its corresponding example. This key gives us just that.
|
||||
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
||||
|
||||
# For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
|
||||
# corresponding example_id and we will store the offset mappings.
|
||||
tokenized_examples["example_id"] = []
|
||||
|
||||
for i in range(len(tokenized_examples["input_ids"])):
|
||||
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
|
||||
sequence_ids = tokenized_examples.sequence_ids(i)
|
||||
context_index = 1 if pad_on_right else 0
|
||||
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = sample_mapping[i]
|
||||
tokenized_examples["example_id"].append(examples["id"][sample_index])
|
||||
|
||||
# Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
|
||||
# position is part of the context or not.
|
||||
tokenized_examples["offset_mapping"][i] = [
|
||||
(o if sequence_ids[k] == context_index else None)
|
||||
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
|
||||
]
|
||||
|
||||
return tokenized_examples
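The offset masking at the end of `prepare_validation_features` is what later lets post-processing discard predictions that point into the question or at special tokens. A minimal sketch of that step on plain lists follows; `mask_non_context_offsets` is a hypothetical helper name, and the toy `sequence_ids` assume the question comes first (the `pad_on_right` case).

```python
def mask_non_context_offsets(offsets, sequence_ids, context_index=1):
    """Keep offsets of context tokens and replace all others with None."""
    return [
        (offset if sequence_ids[k] == context_index else None)
        for k, offset in enumerate(offsets)
    ]


# Toy feature: [CLS] who ? [SEP] the cat [SEP]
toy_offsets = [(0, 0), (0, 3), (3, 4), (0, 0), (0, 3), (4, 7), (0, 0)]
toy_sequence_ids = [None, 0, 0, None, 1, 1, None]
masked = mask_non_context_offsets(toy_offsets, toy_sequence_ids)
```

After masking, a simple `is None` check on an offset is enough to tell whether a token position belongs to the context.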
|
||||
|
||||
if training_args.do_eval or model_args.save_onnx:
|
||||
if "validation" not in raw_datasets:
|
||||
raise ValueError("--do_eval and --save_onnx require a validation dataset")
|
||||
eval_examples = raw_datasets["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
# We will select samples from the whole dataset
|
||||
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
|
||||
# Validation Feature Creation
|
||||
with training_args.main_process_first(desc="validation dataset map pre-processing"):
|
||||
eval_dataset = eval_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on validation dataset",
|
||||
)
|
||||
if data_args.max_eval_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
if training_args.do_predict:
|
||||
if "test" not in raw_datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
predict_examples = raw_datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
# We will select samples from the whole dataset
|
||||
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
|
||||
# Predict Feature Creation
|
||||
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
|
||||
predict_dataset = predict_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on prediction dataset",
|
||||
)
|
||||
if data_args.max_predict_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
|
||||
# Data collator
|
||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||
# collator.
|
||||
data_collator = (
|
||||
default_data_collator
|
||||
if data_args.pad_to_max_length
|
||||
else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
|
||||
)
|
||||
|
||||
# Post-processing:
|
||||
def post_processing_function(examples, features, predictions, stage="eval"):
|
||||
# Post-processing: we match the start logits and end logits to answers in the original context.
|
||||
predictions = postprocess_qa_predictions(
|
||||
examples=examples,
|
||||
features=features,
|
||||
predictions=predictions,
|
||||
version_2_with_negative=data_args.version_2_with_negative,
|
||||
n_best_size=data_args.n_best_size,
|
||||
max_answer_length=data_args.max_answer_length,
|
||||
null_score_diff_threshold=data_args.null_score_diff_threshold,
|
||||
output_dir=training_args.output_dir,
|
||||
log_level=log_level,
|
||||
prefix=stage,
|
||||
)
|
||||
# Format the result to the format the metric expects.
|
||||
if data_args.version_2_with_negative:
|
||||
formatted_predictions = [
|
||||
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
|
||||
]
|
||||
else:
|
||||
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
||||
|
||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||
|
||||
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
||||
|
||||
def compute_metrics(p: EvalPrediction):
|
||||
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset if training_args.do_train or model_args.do_calib else None,
|
||||
eval_dataset=eval_dataset if training_args.do_eval or model_args.save_onnx else None,
|
||||
eval_examples=eval_examples if training_args.do_eval or model_args.save_onnx else None,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
post_process_function=post_processing_function,
|
||||
compute_metrics=compute_metrics,
|
||||
quant_trainer_args=quant_trainer_args,
|
||||
)
|
||||
|
||||
# Calibration
|
||||
if model_args.do_calib:
|
||||
logger.info("*** Calibrate ***")
|
||||
results = trainer.calibrate()
|
||||
trainer.save_model()
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
|
||||
quant_trainer.configure_model(trainer.model, quant_trainer_args)
|
||||
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
metrics = train_result.metrics
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
quant_trainer.configure_model(trainer.model, quant_trainer_args, eval=True)
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Prediction
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
results = trainer.predict(predict_dataset, predict_examples)
|
||||
metrics = results.metrics
|
||||
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||
)
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
if training_args.push_to_hub:
|
||||
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"}
|
||||
if data_args.dataset_name is not None:
|
||||
kwargs["dataset_tags"] = data_args.dataset_name
|
||||
if data_args.dataset_config_name is not None:
|
||||
kwargs["dataset_args"] = data_args.dataset_config_name
|
||||
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
||||
else:
|
||||
kwargs["dataset"] = data_args.dataset_name
|
||||
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
if model_args.save_onnx:
|
||||
logger.info("Exporting model to onnx")
|
||||
results = trainer.save_onnx(output_dir=training_args.output_dir)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,212 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
# Copyright 2021 NVIDIA Corporation. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
A subclass of `Trainer` specific to Question-Answering tasks
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import quant_trainer
|
||||
from transformers import Trainer, is_torch_tpu_available
|
||||
from transformers.trainer_utils import PredictionOutput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
|
||||
|
||||
class QuestionAnsweringTrainer(Trainer):
|
||||
def __init__(self, *args, eval_examples=None, post_process_function=None, quant_trainer_args=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.eval_examples = eval_examples
|
||||
self.post_process_function = post_process_function
|
||||
self.quant_trainer_args = quant_trainer_args
|
||||
self.calib_num = 128 # default number of calibration samples
|
||||
|
||||
def get_calib_dataloader(self, calib_dataset=None):
|
||||
"""
|
||||
Returns the calibration dataloader :class:`~torch.utils.data.DataLoader`.
|
||||
|
||||
Args:
|
||||
calib_dataset (:obj:`torch.utils.data.Dataset`, `optional`)
|
||||
"""
|
||||
if calib_dataset is None and getattr(self, "calib_dataset", None) is None:
|
||||
raise ValueError("Trainer: calibration requires a calib_dataset.")
|
||||
calib_dataset = calib_dataset if calib_dataset is not None else self.calib_dataset
|
||||
|
||||
calib_dataset = self._remove_unused_columns(calib_dataset, description="Calibration")
|
||||
|
||||
return DataLoader(
|
||||
calib_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
def calibrate(self, calib_dataset=None):
|
||||
calib_dataset = self.train_dataset if calib_dataset is None else calib_dataset
|
||||
calib_dataloader = self.get_calib_dataloader(calib_dataset)
|
||||
|
||||
model = self.model
|
||||
quant_trainer.configure_model(model, self.quant_trainer_args, calib=True)
|
||||
model.eval()
|
||||
quant_trainer.enable_calibration(model)
|
||||
|
||||
logger.info("***** Running calibration *****")
|
||||
logger.info(f" Num examples = {self.calib_num}")
|
||||
logger.info(f" Batch size = {calib_dataloader.batch_size}")
|
||||
|
||||
for step, inputs in enumerate(calib_dataloader):
|
||||
# Prediction step
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only=True)
|
||||
if (step + 1) * calib_dataloader.batch_size >= self.calib_num:
|
||||
break
|
||||
|
||||
quant_trainer.finish_calibration(model, self.quant_trainer_args)
|
||||
self.model = model
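The stopping rule in `calibrate` counts whole batches, so the number of samples actually seen can exceed `calib_num` when the batch size does not divide it evenly. The sketch below only reproduces that arithmetic; `calibration_batches` is a hypothetical helper that is not part of the trainer, and it assumes every batch is full.

```python
def calibration_batches(batch_size, calib_num, available_batches):
    """Return how many batches the calibration loop consumes before it breaks."""
    consumed = 0
    for step in range(available_batches):
        consumed = step + 1
        # Same condition as the loop in `calibrate`: stop once enough samples were seen.
        if (step + 1) * batch_size >= calib_num:
            break
    return consumed


even_case = calibration_batches(batch_size=8, calib_num=128, available_batches=100)
uneven_case = calibration_batches(batch_size=48, calib_num=128, available_batches=100)
samples_seen_uneven = uneven_case * 48
```

With a batch size of 48, three batches (144 samples) are consumed even though the log line reports `calib_num` as the number of examples.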
|
||||
|
||||
def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
|
||||
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
eval_examples = self.eval_examples if eval_examples is None else eval_examples
|
||||
|
||||
# Temporarily disable metric computation, we will do it in the loop here.
|
||||
compute_metrics = self.compute_metrics
|
||||
self.compute_metrics = None
|
||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||
try:
|
||||
output = eval_loop(
|
||||
eval_dataloader,
|
||||
description="Evaluation",
|
||||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
)
|
||||
finally:
|
||||
self.compute_metrics = compute_metrics
|
||||
|
||||
if self.post_process_function is not None and self.compute_metrics is not None:
|
||||
eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
|
||||
metrics = self.compute_metrics(eval_preds)
|
||||
|
||||
# Prefix all keys with metric_key_prefix + '_'
|
||||
for key in list(metrics.keys()):
|
||||
if not key.startswith(f"{metric_key_prefix}_"):
|
||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
|
||||
self.log(metrics)
|
||||
else:
|
||||
metrics = {}
|
||||
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||
xm.master_print(met.metrics_report())
|
||||
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
return metrics
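The key-prefixing loop used in `evaluate` and `predict` can be tried in isolation. This is a minimal sketch on a plain dictionary; `prefix_metric_keys` is a hypothetical name, and the metric values are made up for illustration.

```python
def prefix_metric_keys(metrics, metric_key_prefix):
    """Prefix every key with `metric_key_prefix + '_'` unless it already has it."""
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for key in list(metrics.keys()):
        if not key.startswith(f"{metric_key_prefix}_"):
            metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
    return metrics


prefixed = prefix_metric_keys({"exact_match": 80.0, "eval_f1": 88.0}, "eval")
```

Keys that already carry the prefix are left alone, so calling the helper twice does not produce `eval_eval_f1`.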
|
||||
|
||||
def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
|
||||
predict_dataloader = self.get_test_dataloader(predict_dataset)
|
||||
|
||||
# Temporarily disable metric computation, we will do it in the loop here.
|
||||
compute_metrics = self.compute_metrics
|
||||
self.compute_metrics = None
|
||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||
try:
|
||||
output = eval_loop(
|
||||
predict_dataloader,
|
||||
description="Prediction",
|
||||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
)
|
||||
finally:
|
||||
self.compute_metrics = compute_metrics
|
||||
|
||||
if self.post_process_function is None or self.compute_metrics is None:
|
||||
return output
|
||||
|
||||
predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
|
||||
metrics = self.compute_metrics(predictions)
|
||||
|
||||
# Prefix all keys with metric_key_prefix + '_'
|
||||
for key in list(metrics.keys()):
|
||||
if not key.startswith(f"{metric_key_prefix}_"):
|
||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
|
||||
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
|
||||
|
||||
def save_onnx(self, output_dir="./"):
|
||||
eval_dataset = self.eval_dataset
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
|
||||
batch = next(iter(eval_dataloader))
|
||||
|
||||
# saving device - to make it consistent
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# convert to tuple
|
||||
input_tuple = tuple(v.to(device) for k, v in batch.items())
|
||||
|
||||
logger.info("Converting model to be onnx compatible")
|
||||
from pytorch_quantization.nn import TensorQuantizer
|
||||
|
||||
TensorQuantizer.use_fb_fake_quant = True
|
||||
|
||||
model = self.model.to(device)
|
||||
|
||||
model.eval()
|
||||
model.float()
|
||||
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
quant_trainer.configure_model(model_to_save, self.quant_trainer_args)
|
||||
|
||||
output_model_file = os.path.join(output_dir, "model.onnx")
|
||||
logger.info(f"exporting model to {output_model_file}")
|
||||
|
||||
axes = {0: "batch_size", 1: "seq_len"}
|
||||
|
||||
torch.onnx.export(
|
||||
model_to_save,
|
||||
input_tuple,
|
||||
output_model_file,
|
||||
export_params=True,
|
||||
opset_version=13,
|
||||
do_constant_folding=True,
|
||||
input_names=["input_ids", "attention_mask", "token_type_ids"],
|
||||
output_names=["output_start_logits", "output_end_logits"],
|
||||
dynamic_axes={
|
||||
"input_ids": axes,
|
||||
"attention_mask": axes,
|
||||
"token_type_ids": axes,
|
||||
"output_start_logits": axes,
|
||||
"output_end_logits": axes,
|
||||
},
|
||||
verbose=True,
|
||||
)
|
||||
logger.info("onnx export finished")
|
||||
427
examples/research_projects/quantization-qdqbert/utils_qa.py
Normal file
@@ -0,0 +1,427 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Post-processing utilities for question answering.
|
||||
"""
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def postprocess_qa_predictions(
|
||||
examples,
|
||||
features,
|
||||
predictions: Tuple[np.ndarray, np.ndarray],
|
||||
version_2_with_negative: bool = False,
|
||||
n_best_size: int = 20,
|
||||
max_answer_length: int = 30,
|
||||
null_score_diff_threshold: float = 0.0,
|
||||
output_dir: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
log_level: Optional[int] = logging.WARNING,
|
||||
):
|
||||
"""
|
||||
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
|
||||
original contexts. This is the base postprocessing function for models that only return start and end logits.
|
||||
|
||||
Args:
|
||||
examples: The non-preprocessed dataset (see the main script for more information).
|
||||
features: The processed dataset (see the main script for more information).
|
||||
predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
|
||||
The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
|
||||
first dimension must match the number of elements of :obj:`features`.
|
||||
version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the underlying dataset contains examples with no answers.
|
||||
n_best_size (:obj:`int`, `optional`, defaults to 20):
|
||||
The total number of n-best predictions to generate when looking for an answer.
|
||||
max_answer_length (:obj:`int`, `optional`, defaults to 30):
|
||||
The maximum length of an answer that can be generated. This is needed because the start and end predictions
|
||||
are not conditioned on one another.
|
||||
null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
|
||||
The threshold used to select the null answer: if the best answer has a score that is less than the score of
|
||||
the null answer minus this threshold, the null answer is selected for this example (note that the score of
|
||||
the null answer for an example giving several features is the minimum of the scores for the null answer on
|
||||
each feature: all features must be aligned on the fact they `want` to predict a null answer).
|
||||
|
||||
Only useful when :obj:`version_2_with_negative` is :obj:`True`.
|
||||
output_dir (:obj:`str`, `optional`):
|
||||
If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
|
||||
:obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
|
||||
answers, are saved in `output_dir`.
|
||||
prefix (:obj:`str`, `optional`):
|
||||
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
|
||||
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
|
||||
``logging`` log level (e.g., ``logging.WARNING``)
|
||||
"""
|
||||
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
|
||||
all_start_logits, all_end_logits = predictions
|
||||
|
||||
assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
|
||||
|
||||
# Build a map example to its corresponding features.
|
||||
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
|
||||
features_per_example = collections.defaultdict(list)
|
||||
for i, feature in enumerate(features):
|
||||
features_per_example[example_id_to_index[feature["example_id"]]].append(i)
|
||||
|
||||
# The dictionaries we have to fill.
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
if version_2_with_negative:
|
||||
scores_diff_json = collections.OrderedDict()
|
||||
|
||||
# Logging.
|
||||
logger.setLevel(log_level)
|
||||
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
||||
|
||||
# Let's loop over all the examples!
|
||||
for example_index, example in enumerate(tqdm(examples)):
|
||||
# Those are the indices of the features associated to the current example.
|
||||
feature_indices = features_per_example[example_index]
|
||||
|
||||
min_null_prediction = None
|
||||
prelim_predictions = []
|
||||
|
||||
# Looping through all the features associated to the current example.
|
||||
for feature_index in feature_indices:
|
||||
# We grab the predictions of the model for this feature.
|
||||
start_logits = all_start_logits[feature_index]
|
||||
end_logits = all_end_logits[feature_index]
|
||||
# This is what will allow us to map some of the positions in our logits to spans of text in the original
|
||||
# context.
|
||||
offset_mapping = features[feature_index]["offset_mapping"]
|
||||
# Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
|
||||
# available in the current feature.
|
||||
token_is_max_context = features[feature_index].get("token_is_max_context", None)
|
||||
|
||||
# Update minimum null prediction.
|
||||
feature_null_score = start_logits[0] + end_logits[0]
|
||||
if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
|
||||
min_null_prediction = {
|
||||
"offsets": (0, 0),
|
||||
"score": feature_null_score,
|
||||
"start_logit": start_logits[0],
|
||||
"end_logit": end_logits[0],
|
||||
}
|
||||
|
||||
# Go through all possibilities for the `n_best_size` greatest start and end logits.
|
||||
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
|
||||
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
|
||||
for start_index in start_indexes:
|
||||
for end_index in end_indexes:
|
||||
# Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
|
||||
# to part of the input_ids that are not in the context.
|
||||
if (
|
||||
start_index >= len(offset_mapping)
|
||||
or end_index >= len(offset_mapping)
|
||||
or offset_mapping[start_index] is None
|
||||
or offset_mapping[end_index] is None
|
||||
):
|
||||
continue
|
||||
# Don't consider answers with a length that is either < 0 or > max_answer_length.
|
||||
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
|
||||
continue
|
||||
# Don't consider answers that don't have the maximum context available (if such information is
|
||||
# provided).
|
||||
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
|
||||
continue
|
||||
prelim_predictions.append(
|
||||
{
|
||||
"offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
|
||||
"score": start_logits[start_index] + end_logits[end_index],
|
||||
"start_logit": start_logits[start_index],
|
||||
"end_logit": end_logits[end_index],
|
||||
}
|
||||
)
|
||||
if version_2_with_negative:
|
||||
# Add the minimum null prediction
|
||||
prelim_predictions.append(min_null_prediction)
|
||||
null_score = min_null_prediction["score"]
|
||||
|
||||
# Only keep the best `n_best_size` predictions.
|
||||
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
|
||||
|
||||
# Add back the minimum null prediction if it was removed because of its low score.
|
||||
if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
|
||||
predictions.append(min_null_prediction)
|
||||
|
||||
# Use the offsets to gather the answer text in the original context.
|
||||
context = example["context"]
|
||||
for pred in predictions:
|
||||
offsets = pred.pop("offsets")
|
||||
pred["text"] = context[offsets[0] : offsets[1]]
|
||||
|
||||
# In the very rare edge case where we have no non-null prediction at all, we create a fake prediction to avoid
|
||||
# failure.
|
||||
if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
|
||||
predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})
|
||||
|
||||
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
|
||||
# the LogSumExp trick).
|
||||
scores = np.array([pred.pop("score") for pred in predictions])
|
||||
exp_scores = np.exp(scores - np.max(scores))
|
||||
probs = exp_scores / exp_scores.sum()
|
||||
|
||||
# Include the probabilities in our predictions.
|
||||
for prob, pred in zip(probs, predictions):
|
||||
pred["probability"] = prob
|
||||
|
||||
# Pick the best prediction. If the null answer is not possible, this is easy.
|
||||
if not version_2_with_negative:
|
||||
all_predictions[example["id"]] = predictions[0]["text"]
|
||||
else:
|
||||
# Otherwise we first need to find the best non-empty prediction.
|
||||
i = 0
|
||||
while predictions[i]["text"] == "":
|
||||
i += 1
|
||||
best_non_null_pred = predictions[i]
|
||||
|
||||
# Then we compare to the null prediction using the threshold.
|
||||
score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
|
||||
scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
|
||||
if score_diff > null_score_diff_threshold:
|
||||
all_predictions[example["id"]] = ""
|
||||
else:
|
||||
all_predictions[example["id"]] = best_non_null_pred["text"]
|
||||
|
||||
# Make `predictions` JSON-serializable by casting np.float back to float.
|
||||
all_nbest_json[example["id"]] = [
|
||||
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
for pred in predictions
|
||||
]
|
||||
|
||||
# If we have an output_dir, let's save all those dicts.
|
||||
if output_dir is not None:
|
||||
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
)
|
||||
nbest_file = os.path.join(
|
||||
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
|
||||
)
|
||||
if version_2_with_negative:
|
||||
null_odds_file = os.path.join(
|
||||
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
|
||||
)
|
||||
|
||||
logger.info(f"Saving predictions to {prediction_file}.")
|
||||
with open(prediction_file, "w") as writer:
|
||||
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||
logger.info(f"Saving nbest_preds to {nbest_file}.")
|
||||
with open(nbest_file, "w") as writer:
|
||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||
if version_2_with_negative:
|
||||
logger.info(f"Saving null_odds to {null_odds_file}.")
|
||||
with open(null_odds_file, "w") as writer:
|
||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||
|
||||
return all_predictions
|
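The null-answer decision at the end of the function above can be isolated into a small sketch. The helper name `pick_answer` and the toy logits are hypothetical and are not part of the script; only the score-difference formula and the threshold comparison are taken from the code above.

```python
def pick_answer(null_score, best_non_null_pred, null_score_diff_threshold=0.0):
    """Hypothetical helper: choose between the empty answer and the best span.

    Mirrors the comparison above: the null score minus the start and end
    logits of the best non-empty prediction is checked against a threshold.
    """
    score_diff = (
        null_score
        - best_non_null_pred["start_logit"]
        - best_non_null_pred["end_logit"]
    )
    if score_diff > null_score_diff_threshold:
        # The null answer wins: predict "no answer".
        return "", float(score_diff)
    return best_non_null_pred["text"], float(score_diff)


# Toy values, invented for illustration only.
best = {"text": "Paris", "start_logit": 4.0, "end_logit": 3.5}
span_case = pick_answer(1.0, best)  # null score well below the span score
null_case = pick_answer(9.0, best)  # null score above the span score
print(span_case, null_case)
```

Raising `null_score_diff_threshold` makes the function less willing to return the empty answer; lowering it has the opposite effect.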
||||
|
||||
|
||||
def postprocess_qa_predictions_with_beam_search(
|
||||
examples,
|
||||
features,
|
||||
predictions: Tuple[np.ndarray, np.ndarray],
|
||||
version_2_with_negative: bool = False,
|
||||
n_best_size: int = 20,
|
||||
max_answer_length: int = 30,
|
||||
start_n_top: int = 5,
|
||||
end_n_top: int = 5,
|
||||
output_dir: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
log_level: Optional[int] = logging.WARNING,
|
||||
):
|
||||
"""
|
||||
Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
|
||||
original contexts. This is the postprocessing function for models that return start and end logits, indices, as well as
|
||||
cls token predictions.
|
||||
|
||||
Args:
|
||||
examples: The non-preprocessed dataset (see the main script for more information).
|
||||
features: The processed dataset (see the main script for more information).
|
||||
predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
|
||||
The predictions of the model: five arrays containing the top start log probabilities, the top start indices, the top end log probabilities, the top end indices and the cls logits, in that order. Its
|
||||
first dimension must match the number of elements of :obj:`features`.
|
||||
version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the underlying dataset contains examples with no answers.
|
||||
n_best_size (:obj:`int`, `optional`, defaults to 20):
|
||||
The total number of n-best predictions to generate when looking for an answer.
|
||||
max_answer_length (:obj:`int`, `optional`, defaults to 30):
|
||||
The maximum length of an answer that can be generated. This is needed because the start and end predictions
|
||||
are not conditioned on one another.
|
||||
start_n_top (:obj:`int`, `optional`, defaults to 5):
|
||||
The number of top start logits to keep when searching for the :obj:`n_best_size` predictions.
|
||||
end_n_top (:obj:`int`, `optional`, defaults to 5):
|
||||
The number of top end logits to keep when searching for the :obj:`n_best_size` predictions.
|
||||
output_dir (:obj:`str`, `optional`):
|
||||
If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
|
||||
:obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
|
||||
answers, are saved in `output_dir`.
|
||||
prefix (:obj:`str`, `optional`):
|
||||
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
|
||||
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
|
||||
``logging`` log level (e.g., ``logging.WARNING``)
|
||||
"""
|
||||
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
|
||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
|
||||
|
||||
assert len(predictions[0]) == len(
|
||||
features
|
||||
), f"Got {len(predictions[0])} predictions and {len(features)} features."
|
||||
|
||||
# Build a map from each example to its corresponding features.
|
||||
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
|
||||
features_per_example = collections.defaultdict(list)
|
||||
for i, feature in enumerate(features):
|
||||
features_per_example[example_id_to_index[feature["example_id"]]].append(i)
|
||||
|
||||
# The dictionaries we have to fill.
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
|
||||
|
||||
# Logging.
|
||||
logger.setLevel(log_level)
|
||||
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
||||
|
||||
# Let's loop over all the examples!
|
||||
for example_index, example in enumerate(tqdm(examples)):
|
||||
# Those are the indices of the features associated to the current example.
|
||||
feature_indices = features_per_example[example_index]
|
||||
|
||||
min_null_score = None
|
||||
prelim_predictions = []
|
||||
|
||||
# Looping through all the features associated to the current example.
|
||||
for feature_index in feature_indices:
|
||||
# We grab the predictions of the model for this feature.
|
||||
start_log_prob = start_top_log_probs[feature_index]
|
||||
start_indexes = start_top_index[feature_index]
|
||||
end_log_prob = end_top_log_probs[feature_index]
|
||||
end_indexes = end_top_index[feature_index]
|
||||
feature_null_score = cls_logits[feature_index]
|
||||
# This is what will allow us to map some of the positions in our logits to spans of text in the original
|
||||
# context.
|
||||
offset_mapping = features[feature_index]["offset_mapping"]
|
||||
# Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
|
||||
# available in the current feature.
|
||||
token_is_max_context = features[feature_index].get("token_is_max_context", None)
|
||||
|
||||
# Update minimum null prediction
|
||||
if min_null_score is None or feature_null_score < min_null_score:
|
||||
min_null_score = feature_null_score
|
||||
|
||||
# Go through all possibilities for the `start_n_top`/`end_n_top` greatest start and end logits.
|
||||
for i in range(start_n_top):
|
||||
for j in range(end_n_top):
|
||||
start_index = int(start_indexes[i])
|
||||
j_index = i * end_n_top + j
|
||||
end_index = int(end_indexes[j_index])
|
||||
# Don't consider out-of-scope answers (last part of the test should be unnecessary because of the
|
||||
# p_mask but let's not take any risk)
|
||||
if (
|
||||
start_index >= len(offset_mapping)
|
||||
or end_index >= len(offset_mapping)
|
||||
or offset_mapping[start_index] is None
|
||||
or offset_mapping[end_index] is None
|
||||
):
|
||||
continue
|
||||
# Don't consider answers with a negative length or a length > max_answer_length.
|
||||
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
|
||||
continue
|
||||
# Don't consider answers that don't have the maximum context available (if such information is
|
||||
# provided).
|
||||
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
|
||||
continue
|
||||
prelim_predictions.append(
|
||||
{
|
||||
"offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
|
||||
"score": start_log_prob[i] + end_log_prob[j_index],
|
||||
"start_log_prob": start_log_prob[i],
|
||||
"end_log_prob": end_log_prob[j_index],
|
||||
}
|
||||
)
|
||||
|
||||
# Only keep the best `n_best_size` predictions.
|
||||
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
|
||||
|
||||
# Use the offsets to gather the answer text in the original context.
|
||||
context = example["context"]
|
||||
for pred in predictions:
|
||||
offsets = pred.pop("offsets")
|
||||
pred["text"] = context[offsets[0] : offsets[1]]
|
||||
|
||||
# In the very rare edge case where we have no non-null prediction at all, we create a fake prediction to avoid
|
||||
# failure.
|
||||
if len(predictions) == 0:
|
||||
predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
|
||||
|
||||
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
|
||||
# the LogSumExp trick).
|
||||
scores = np.array([pred.pop("score") for pred in predictions])
|
||||
exp_scores = np.exp(scores - np.max(scores))
|
||||
probs = exp_scores / exp_scores.sum()
|
||||
|
||||
# Include the probabilities in our predictions.
|
||||
for prob, pred in zip(probs, predictions):
|
||||
pred["probability"] = prob
|
||||
|
||||
# Pick the best prediction and set the probability for the null answer.
|
||||
all_predictions[example["id"]] = predictions[0]["text"]
|
||||
if version_2_with_negative:
|
||||
scores_diff_json[example["id"]] = float(min_null_score)
|
||||
|
||||
# Make `predictions` JSON-serializable by casting np.float back to float.
|
||||
all_nbest_json[example["id"]] = [
|
||||
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
for pred in predictions
|
||||
]
|
||||
|
||||
# If we have an output_dir, let's save all those dicts.
|
||||
if output_dir is not None:
|
||||
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
)
|
||||
nbest_file = os.path.join(
|
||||
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
|
||||
)
|
||||
if version_2_with_negative:
|
||||
null_odds_file = os.path.join(
|
||||
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
|
||||
)
|
||||
|
||||
logger.info(f"Saving predictions to {prediction_file}.")
|
||||
with open(prediction_file, "w") as writer:
|
||||
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||
logger.info(f"Saving nbest_preds to {nbest_file}.")
|
||||
with open(nbest_file, "w") as writer:
|
||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||
if version_2_with_negative:
|
||||
logger.info(f"Saving null_odds to {null_odds_file}.")
|
||||
with open(null_odds_file, "w") as writer:
|
||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||
|
||||
return all_predictions, scores_diff_json
|
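The pairing of start and end candidates in the beam-search loop above relies on a flat index, `j_index = i * end_n_top + j`: each of the `start_n_top` start candidates owns its own block of `end_n_top` end candidates. A minimal sketch of that enumeration follows. The function name `enumerate_spans` and the toy inputs are hypothetical, and the offset-mapping and `token_is_max_context` checks from the original loop are left out for brevity.

```python
def enumerate_spans(start_indexes, end_indexes, start_log_prob, end_log_prob,
                    start_n_top, end_n_top, max_answer_length):
    """Hypothetical helper: list valid (start, end, score) spans, best first."""
    candidates = []
    for i in range(start_n_top):
        for j in range(end_n_top):
            # End candidates are stored flat: block i holds the end
            # candidates that belong to start candidate i.
            j_index = i * end_n_top + j
            start_index = int(start_indexes[i])
            end_index = int(end_indexes[j_index])
            # Skip spans with a negative length or a length above the maximum.
            if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                continue
            score = float(start_log_prob[i] + end_log_prob[j_index])
            candidates.append((start_index, end_index, score))
    return sorted(candidates, key=lambda c: c[2], reverse=True)


# Toy inputs with start_n_top = end_n_top = 2, invented for illustration.
spans = enumerate_spans(
    start_indexes=[3, 7],
    end_indexes=[4, 2, 9, 8],           # two end candidates per start candidate
    start_log_prob=[-0.1, -2.0],
    end_log_prob=[-0.2, -1.0, -0.3, -0.5],
    start_n_top=2,
    end_n_top=2,
    max_answer_length=30,
)
print([(s, e) for s, e, _ in spans])
```

In the toy data the pair `(3, 2)` is dropped because the end index precedes the start index.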
||||
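The probability step above subtracts the maximum score before exponentiating, which keeps the exponentials from overflowing when scores are large. The script does this with NumPy. The sketch below uses only the standard library `math` module so that it runs without dependencies, and the name `stable_softmax` is hypothetical.

```python
import math


def stable_softmax(scores):
    """Hypothetical helper: softmax with the maximum subtracted first."""
    max_score = max(scores)
    # After shifting, every exponent is <= 0, so math.exp cannot overflow.
    exp_scores = [math.exp(s - max_score) for s in scores]
    total = sum(exp_scores)
    return [e / total for e in exp_scores]


# A naive math.exp(1000.0) would raise OverflowError; the shifted version is fine.
probs = stable_softmax([1000.0, 999.0, 998.0])
print(probs)
```

Shifting all scores by the same constant leaves the resulting probabilities unchanged, since the constant cancels between numerator and denominator.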