I-BERT model support (#10153)

* IBertConfig, IBertTokenizer added

* IBert Model names modified

* tokenizer bugfix

* embedding -> QuantEmbedding

* quant utils added

* quant_mode added to configuration

* QuantAct added, Embedding layer + QuantAct addition

* QuantAct added

* unused path removed, QKV quantized

* self attention layer all quantized, except softmax

* temporary commit

* all linear layers quantized

* quant_utils bugfix

* bugfix: requantization missing

* IntGELU added

* IntSoftmax added

* LayerNorm implemented

* LayerNorm implemented all

* names changed: roberta->ibert

* config not inherit from ROberta

* No support for CausalLM

* static quantization added, quantize_model.py removed

* import modules uncommented

* copyrights fixed

* minor bugfix

* quant_modules, quant_utils merged as one file

* import * fixed

* unused runfile removed

* make style run

* configuration.py docstring fixed

* refactoring: comments removed, function name fixed

* unused dependency removed

* typo fixed

* comments(Copied from), assertion string added

* refactoring: super(..) -> super(), etc.

* refactoring

* refactoring

* make style

* refactoring

* cuda -> to(x.device)

* weight initialization removed

* QuantLinear set_param removed

* QuantEmbedding set_param removed

* IntLayerNorm set_param removed

* assert string added

* assertion error message fixed

* is_decoder removed

* enc-dec arguments/functions removed

* Converter removed

* quant_modules docstring fixed

* convert_slow_tokenizer rolled back

* quant_utils docstring fixed

* unused arguments e.g. use_cache removed from config

* weight initialization condition fixed

* x_min, x_max initialized with small values to avoid div-zero exceptions

* testing code for ibert

* test emb, linear, gelu, softmax added

* test ln and act added

* style reformatted

* force_dequant added

* error tests overridden

* make style

* Style + Docs

* force dequant tests added

* Fix fast tokenizer in init

* Fix doc

* Remove space

* docstring, IBertConfig, chunk_size

* test_modeling_ibert refactoring

* quant_modules.py refactoring

* e2e integration test added

* tokenizers removed

* IBertConfig added to tokenizer_auto.py

* bugfix

* fix docs & test

* fix style num 2

* final fixes

Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sehoon Kim
2021-02-26 00:06:42 +09:00
committed by GitHub
parent cb38ffcc5e
commit 63645b3b11
12 changed files with 3279 additions and 0 deletions

View File

@@ -40,6 +40,7 @@ from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE
from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
@@ -110,6 +111,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)
@@ -123,6 +125,7 @@ CONFIG_MAPPING = OrderedDict(
("led", LEDConfig),
("blenderbot-small", BlenderbotSmallConfig),
("retribert", RetriBertConfig),
("ibert", IBertConfig),
("mt5", MT5Config),
("t5", T5Config),
("mobilebert", MobileBertConfig),
@@ -173,6 +176,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("led", "LED"),
("blenderbot-small", "BlenderbotSmall"),
("retribert", "RetriBERT"),
("ibert", "I-BERT"),
("t5", "T5"),
("mobilebert", "MobileBERT"),
("distilbert", "DistilBERT"),

View File

@@ -129,6 +129,14 @@ from ..funnel.modeling_funnel import (
FunnelModel,
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
from ..ibert.modeling_ibert import (
IBertForMaskedLM,
IBertForMultipleChoice,
IBertForQuestionAnswering,
IBertForSequenceClassification,
IBertForTokenClassification,
IBertModel,
)
from ..layoutlm.modeling_layoutlm import (
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
@@ -270,6 +278,7 @@ from .configuration_auto import (
FSMTConfig,
FunnelConfig,
GPT2Config,
IBertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
@@ -347,6 +356,7 @@ MODEL_MAPPING = OrderedDict(
(MPNetConfig, MPNetModel),
(TapasConfig, TapasModel),
(MarianConfig, MarianModel),
(IBertConfig, IBertModel),
]
)
@@ -379,6 +389,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(FunnelConfig, FunnelForPreTraining),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
(IBertConfig, IBertForMaskedLM),
]
)
@@ -418,6 +429,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(IBertConfig, IBertForMaskedLM),
]
)
@@ -476,6 +488,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(IBertConfig, IBertForMaskedLM),
]
)
@@ -529,6 +542,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(TransfoXLConfig, TransfoXLForSequenceClassification),
(MPNetConfig, MPNetForSequenceClassification),
(TapasConfig, TapasForSequenceClassification),
(IBertConfig, IBertForSequenceClassification),
]
)
@@ -558,6 +572,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
(MPNetConfig, MPNetForQuestionAnswering),
(DebertaConfig, DebertaForQuestionAnswering),
(DebertaV2Config, DebertaV2ForQuestionAnswering),
(IBertConfig, IBertForQuestionAnswering),
]
)
@@ -591,6 +606,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
(MPNetConfig, MPNetForTokenClassification),
(DebertaConfig, DebertaForTokenClassification),
(DebertaV2Config, DebertaV2ForTokenClassification),
(IBertConfig, IBertForTokenClassification),
]
)
@@ -613,6 +629,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
(FlaubertConfig, FlaubertForMultipleChoice),
(FunnelConfig, FunnelForMultipleChoice),
(MPNetConfig, MPNetForMultipleChoice),
(IBertConfig, IBertForMultipleChoice),
]
)

View File

@@ -75,6 +75,7 @@ from .configuration_auto import (
FSMTConfig,
FunnelConfig,
GPT2Config,
IBertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
@@ -244,6 +245,7 @@ TOKENIZER_MAPPING = OrderedDict(
(TapasConfig, (TapasTokenizer, None)),
(LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
(ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
(IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)),
]
)

View File

@@ -0,0 +1,69 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
# Table handed to the lazy module below: submodule name -> public names that submodule provides.
_import_structure = {
    "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
}

# The modeling classes are only advertised when `is_torch_available()` is true.
if is_torch_available():
    _import_structure.update(
        modeling_ibert=[
            "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
            "IBertForMaskedLM",
            "IBertForMultipleChoice",
            "IBertForQuestionAnswering",
            "IBertForSequenceClassification",
            "IBertForTokenClassification",
            "IBertModel",
        ]
    )


if TYPE_CHECKING:
    # Static type checkers get plain imports so that every exported name resolves.
    from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig

    if is_torch_available():
        from .modeling_ibert import (
            IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            IBertForMaskedLM,
            IBertForMultipleChoice,
            IBertForQuestionAnswering,
            IBertForSequenceClassification,
            IBertForTokenClassification,
            IBertModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Relative import resolved against this package's own name.
            return importlib.import_module(f".{module_name}", self.__name__)

    # Replace this module's entry in `sys.modules` with the lazy proxy built from the table above.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)

View File

@@ -0,0 +1,125 @@
# coding=utf-8
# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" I-BERT configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"kssteven/ibert-roberta-base": "https://huggingface.co/kssteven/ibert-roberta-base/resolve/main/config.json",
"kssteven/ibert-roberta-large": "https://huggingface.co/kssteven/ibert-roberta-large/resolve/main/config.json",
"kssteven/ibert-roberta-large-mnli": "https://huggingface.co/kssteven/ibert-roberta-large-mnli/resolve/main/config.json",
}
class IBertConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a :class:`~transformers.IBertModel`. It is used to
    instantiate an I-BERT model according to the specified arguments.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 30522):
            Vocabulary size of the I-BERT model. Defines the number of different tokens that can be represented by the
            :obj:`input_ids` passed when calling :class:`~transformers.IBertModel`
        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (:obj:`int`, `optional`, defaults to 2):
            The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.IBertModel`
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
            Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
            :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
            :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
            <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
            `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
            <https://arxiv.org/abs/2009.13658>`__.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to quantize the model or not.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize specific nonlinear layer. Dequantized layers are then executed with full precision.
            :obj:`"none"`, :obj:`"gelu"`, :obj:`"softmax"`, :obj:`"layernorm"` and :obj:`"nonlinear"` are supported. By
            default, it is set as :obj:`"none"`, which does not dequantize any layers. Please specify :obj:`"gelu"`,
            :obj:`"softmax"`, or :obj:`"layernorm"` to dequantize GELU, Softmax, or LayerNorm, respectively.
            :obj:`"nonlinear"` will dequantize all nonlinear layers, i.e., GELU, Softmax, and LayerNorm.
    """

    model_type = "ibert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        position_embedding_type="absolute",
        quant_mode=False,
        force_dequant="none",
        **kwargs
    ):
        # The special-token ids and any extra keyword arguments are forwarded to the base configuration.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # Transformer encoder hyper-parameters.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type

        # Quantization switches specific to I-BERT.
        self.quant_mode = quant_mode
        self.force_dequant = force_dequant

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,829 @@
# coding=utf-8
# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import decimal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from ...utils import logging
logger = logging.get_logger(__name__)
class QuantEmbedding(nn.Module):
    """
    Quantized version of :obj:`torch.nn.Embedding`. Adds quantization-specific arguments on top of
    :obj:`torch.nn.Embedding`.

    Args:
        weight_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
            Bitwidth for the quantized weight.
        momentum (:obj:`float`, `optional`, defaults to :obj:`0.95`):
            Momentum for updating the activation quantization range.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        weight_bit=8,
        momentum=0.95,
        quant_mode=False,
    ):
        super().__init__()

        # Settings mirrored from `torch.nn.Embedding`; they are passed through to `F.embedding` in `forward`.
        self.num_ = num_embeddings
        self.dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        # Full-precision table plus the buffers that are overwritten on every quantized forward pass.
        self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim]))
        self.register_buffer("weight_scaling_factor", torch.zeros(1))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))

        # Quantization settings.
        self.weight_bit = weight_bit
        self.momentum = momentum
        self.quant_mode = quant_mode
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def forward(self, x, positions=None, incremental_state=None):
        """
        Look up the embeddings for the indices `x`.

        Returns a tuple `(embeddings, scaling_factor)`. `scaling_factor` is :obj:`None` when `quant_mode` is off.
        `positions` and `incremental_state` are accepted but not used.
        """
        lookup_args = (self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)

        if not self.quant_mode:
            return F.embedding(x, self.weight, *lookup_args), None

        # One scaling factor for the whole table, derived from its global min/max.
        detached_weight = self.weight.data.detach()
        w_min = detached_weight.min().expand(1)
        w_max = detached_weight.max().expand(1)

        self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
        )

        emb_int = F.embedding(x, self.weight_integer, *lookup_args)
        return emb_int * self.weight_scaling_factor, self.weight_scaling_factor
class QuantAct(nn.Module):
    """
    Quantizes the given activation.

    Args:
        activation_bit (:obj:`int`):
            Bitwidth for the quantized activation.
        act_range_momentum (:obj:`float`, `optional`, defaults to :obj:`0.95`):
            Momentum for updating the activation quantization range.
        per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use channel-wise quantization. Only :obj:`False` is currently supported.
        channel_len (:obj:`int`, `optional`, defaults to :obj:`None`):
            Channel length to use when `per_channel` is True. It is not read by this class at present.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
    """

    def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False):
        super().__init__()

        self.activation_bit = activation_bit
        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.percentile = False
        self.act_function = SymmetricQuantFunction.apply

        if self.per_channel:
            raise NotImplementedError("per-channel mode is not currently supported for activation.")

        self.register_buffer("x_min", torch.zeros(1))
        self.register_buffer("x_max", torch.zeros(1))
        self.register_buffer("act_scaling_factor", torch.zeros(1))
        # Start from a tiny non-empty range; `forward` treats a range inside +/-1.1e-5 as "not yet initialized".
        self.x_min -= 1e-5
        self.x_max += 1e-5

    def __repr__(self):
        template = "{0}(activation_bit={1}, " "quant_mode: {2}, Act_min: {3:.2f}, " "Act_max: {4:.2f})"
        return template.format(
            self.__class__.__name__, self.activation_bit, self.quant_mode, self.x_min.item(), self.x_max.item()
        )

    def forward(
        self,
        x,
        pre_act_scaling_factor=None,
        identity=None,
        identity_scaling_factor=None,
        specified_min=None,
        specified_max=None,
    ):
        """
        Track the activation range (training only) and, in quant mode, requantize `x`.

        Returns a tuple `(activation, scaling_factor)`. When `quant_mode` is off, the activation is `x` (plus
        `identity` if given) and the scaling factor is :obj:`None`.
        """
        x_act = identity + x if identity is not None else x

        # collect running stats if training
        if self.training:
            assert not self.percentile, "percentile mode is not currently supported for activation."
            assert not self.per_channel, "per-channel mode is not currently supported for activation."

            batch_min = x_act.data.min()
            batch_max = x_act.data.max()
            assert (
                batch_max.isnan().sum() == 0 and batch_min.isnan().sum() == 0
            ), "NaN detected when computing min/max of the activation"

            momentum = self.act_range_momentum
            if self.x_min.min() > -1.1e-5 and self.x_max.max() < 1.1e-5:
                # Initialization: the range still holds its constructor values.
                new_min = self.x_min + batch_min
                new_max = self.x_max + batch_max
            elif momentum == -1:
                # A momentum of -1 selects a running min/max instead of an average.
                new_min = torch.min(self.x_min, batch_min)
                new_max = torch.max(self.x_max, batch_max)
            else:
                # exponential moving average (EMA)
                # use momentum to prevent the quantized values change greatly every iteration
                new_min = self.x_min * momentum + batch_min * (1 - momentum)
                new_max = self.x_max * momentum + batch_max * (1 - momentum)
            self.x_min = new_min
            self.x_max = new_max

        if not self.quant_mode:
            return x_act, None

        # Caller-specified bounds take precedence over the tracked range.
        range_min = specified_min if specified_min is not None else self.x_min
        range_max = specified_max if specified_max is not None else self.x_max

        self.act_scaling_factor = symmetric_linear_quantization_params(
            self.activation_bit, range_min, range_max, per_channel=self.per_channel
        )

        if pre_act_scaling_factor is not None:
            quant_act_int = FixedPointMul.apply(
                x,
                pre_act_scaling_factor,
                self.activation_bit,
                self.act_scaling_factor,
                identity,
                identity_scaling_factor,
            )
        else:
            # this is for the input quantization
            quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor)

        correct_output_scale = self.act_scaling_factor.view(-1)
        return quant_act_int * correct_output_scale, self.act_scaling_factor
class QuantLinear(nn.Module):
    """
    Quantized version of :obj:`torch.nn.Linear`. Adds quantization-specific arguments on top of :obj:`torch.nn.Linear`.

    Args:
        weight_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
            Bitwidth for the quantized weight.
        bias_bit (:obj:`int`, `optional`, defaults to :obj:`32`):
            Bitwidth for the quantized bias.
        per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use channel-wise quantization.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))
        self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
            self.register_buffer("bias_integer", torch.zeros_like(self.bias))
        else:
            # Register the names as None so that `forward` can read `self.bias` / `self.bias_integer`
            # on a bias-free layer instead of failing with an AttributeError.
            self.register_parameter("bias", None)
            self.register_buffer("bias_integer", None)

        self.weight_bit = weight_bit
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.bias_bit = bias_bit
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def __repr__(self):
        s = super().__repr__()
        s = "(" + s + " weight_bit={}, quant_mode={})".format(self.weight_bit, self.quant_mode)
        return s

    def forward(self, x, prev_act_scaling_factor=None):
        """
        Apply the linear layer to `x`.

        Returns a tuple `(output, scaling_factor)`. `scaling_factor` is :obj:`None` when `quant_mode` is off. In quant
        mode `prev_act_scaling_factor` must be a tensor of shape `(1,)`, otherwise an `AssertionError` is raised.
        """
        if not self.quant_mode:
            return F.linear(x, weight=self.weight, bias=self.bias), None

        # assert that prev_act_scaling_factor is a scalar tensor
        assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
            "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
            "Please add a QuantAct layer with `per_channel = False` before this QuantLinear layer"
        )

        w_transform = self.weight.data.detach()
        if self.per_channel:
            # One (min, max) pair per output row.
            w_min, _ = torch.min(w_transform, dim=1)
            w_max, _ = torch.max(w_transform, dim=1)
        else:
            w_min = w_transform.min().expand(1)
            w_max = w_transform.max().expand(1)

        self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
        )

        # The bias (and the layer output) use the product of the weight and input-activation scales.
        bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor

        if self.bias is not None:
            self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)

        prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
        x_int = x / prev_act_scaling_factor

        return (
            F.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
            bias_scaling_factor,
        )
class IntGELU(nn.Module):
    """
    Quantized version of :obj:`torch.nn.GELU`. Adds quantization-specific arguments on top of :obj:`torch.nn.GELU`.

    Args:
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the layer is quantized.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize the layer if either "gelu" or "nonlinear" is given.
    """

    def __init__(self, quant_mode=True, force_dequant="none"):
        super().__init__()

        dequantize = force_dequant in ["nonlinear", "gelu"]
        if dequantize:
            logger.info("Force dequantize gelu")
        self.quant_mode = False if dequantize else quant_mode

        # The full-precision GELU is only created when the layer is not quantized.
        if not self.quant_mode:
            self.activation_fn = nn.GELU()

        self.k = 1.4142
        self.const = 14  # dummy integer constant
        # a(x+b)**2 + c, stored as [a, b, c / a]
        a, b, c = -0.2888, -1.769, 1
        self.coeff = [a, b, c / a]

    def int_erf(self, x_int, scaling_factor):
        """
        Evaluate the second-order polynomial `a(x+b)**2 + c` on `x_int` and return `(y_int, scaling_factor)`.
        """
        a, b, c_over_a = self.coeff
        b_int = torch.floor(b / scaling_factor)
        c_int = torch.floor(c_over_a / scaling_factor ** 2)

        sign = torch.sign(x_int)
        # `-b_int` is the upper limit applied to the magnitude of the input.
        abs_int = torch.min(torch.abs(x_int), -b_int)
        y_int = sign * ((abs_int + b_int) ** 2 + c_int)
        out_scale = scaling_factor ** 2 * a

        # avoid overflow
        y_int = floor_ste.apply(y_int / 2 ** self.const)
        out_scale = out_scale * 2 ** self.const

        return y_int, out_scale

    def forward(self, x, scaling_factor=None):
        """
        Apply GELU to `x` and return `(output, scaling_factor)`; the scaling factor is :obj:`None` when the layer is
        not quantized.
        """
        if not self.quant_mode:
            return self.activation_fn(x), None

        x_int = x / scaling_factor
        sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)

        # Integer representation of the constant 1 under the polynomial's output scale.
        shift_int = 1.0 // sigmoid_scaling_factor

        x_int = x_int * (sigmoid_int + shift_int)
        out_scale = scaling_factor * sigmoid_scaling_factor / 2

        return x_int * out_scale, out_scale
class IntSoftmax(nn.Module):
    """
    Quantized version of :obj:`torch.nn.Softmax`. Adds quantization-specific arguments on top of
    :obj:`torch.nn.Softmax`.

    Args:
        output_bit (:obj:`int`):
            Bitwidth for the layer output activation.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize the layer if either "softmax" or "nonlinear" is given.
    """

    def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.output_bit = output_bit
        self.max_bit = 32

        dequantize = force_dequant in ["nonlinear", "softmax"]
        if dequantize:
            logger.info("Force dequantize softmax")
        self.quant_mode = False if dequantize else quant_mode

        # 16-bit requantization applied to the exponentials in `forward`.
        self.act = QuantAct(16, quant_mode=self.quant_mode)
        self.x0 = -0.6931  # -ln2
        self.const = 30  # dummy integer constant
        # ax**2 + bx + c, stored as [a, b / a, c / a]
        a, b, c = 0.35815147, 0.96963238, 1.0
        self.coef = [a, b / a, c / a]

    def int_polynomial(self, x_int, scaling_factor):
        """
        Evaluate the second-order polynomial on `x_int` and return `(z, scaling_factor)`.
        """
        a, b_over_a, c_over_a = self.coef
        with torch.no_grad():
            b_int = torch.floor(b_over_a / scaling_factor)
            c_int = torch.floor(c_over_a / scaling_factor ** 2)
        z = (x_int + b_int) * x_int + c_int
        return z, a * scaling_factor ** 2

    def int_exp(self, x_int, scaling_factor):
        """
        Integer approximation of `exp` for inputs `x_int`; returns `(exp_int, scaling_factor)`.
        """
        with torch.no_grad():
            x0_int = torch.floor(self.x0 / scaling_factor)
        # Limit the input from below at `self.const * x0_int`.
        x_int = torch.max(x_int, self.const * x0_int)

        # Split into quotient and remainder with respect to x0 (= -ln2): x_int = x0_int * q + r.
        q = floor_ste.apply(x_int / x0_int)
        r = x_int - x0_int * q

        poly_int, poly_scale = self.int_polynomial(r, scaling_factor)
        exp_int = torch.clamp(floor_ste.apply(poly_int * 2 ** (self.const - q)), min=0)
        return exp_int, poly_scale / 2 ** self.const

    def forward(self, x, scaling_factor):
        """
        Apply softmax over the last dimension of `x` and return `(output, scaling_factor)`; the scaling factor is
        :obj:`None` when the layer is not quantized.
        """
        if not self.quant_mode:
            return F.softmax(x, dim=-1), None

        x_int = x / scaling_factor

        # Subtract the row-wise maximum so that every input to `int_exp` is non-positive.
        row_max, _ = x_int.max(dim=-1, keepdim=True)
        x_int = x_int - row_max
        exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)

        # Avoid overflow
        exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
        exp_int = exp / exp_scaling_factor

        exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
        factor = floor_ste.apply(2 ** self.max_bit / exp_int_sum)
        exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
        out_scale = 1 / 2 ** self.output_bit
        return exp_int * out_scale, out_scale
class IntLayerNorm(nn.Module):
    """
    Quantized version of :obj:`torch.nn.LayerNorm`. Adds quantization-specific arguments on top of
    :obj:`torch.nn.LayerNorm`.

    Args:
        output_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
            Bitwidth for the layer output activation.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize the layer if either "layernorm" or "nonlinear" is given.
    """

    def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        self.weight = nn.Parameter(torch.zeros(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

        dequantize = force_dequant in ["nonlinear", "layernorm"]
        if dequantize:
            logger.info("Force dequantize layernorm")
        self.quant_mode = False if dequantize else quant_mode

        # Power-of-two right shift applied before squaring; raised by `set_shift` when overflow is detected.
        self.register_buffer("shift", torch.zeros(1))
        self.output_bit = output_bit
        self.max_bit = 32
        self.dim_sqrt = None  # filled in on the first quantized forward pass
        self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode)

    def set_shift(self, y_int):
        """
        Raise `self.shift` (it never decreases) to the value computed from the sum of squares of `y_int`.
        """
        with torch.no_grad():
            sum_sq_int = torch.sum(y_int ** 2, dim=2, keepdim=True)
            shift = (torch.log2(torch.sqrt(sum_sq_int / 2 ** self.max_bit)).ceil()).max()
            shift_old = self.shift
            self.shift = torch.max(self.shift, shift)
            logger.info("Dynamic shift adjustment: {} -> {}".format(int(shift_old), int(self.shift)))

    def overflow_fallback(self, y_int):
        """
        This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
        to avoid overflow in the subsequent runs.
        """
        self.set_shift(y_int)  # adjusts `self.shift`
        shifted = floor_ste.apply(y_int / 2 ** self.shift)
        return torch.sum(shifted ** 2, dim=2, keepdim=True)

    def forward(self, x, scaling_factor=None):
        """
        Normalize `x` over dimension 2 and return `(output, scaling_factor)`; the scaling factor is :obj:`None` when
        the layer is not quantized.
        """
        if not self.quant_mode:
            centered = x - x.mean(dim=2, keepdim=True)
            variance = torch.mean(centered ** 2, dim=2, keepdim=True)
            normalized = centered / torch.sqrt(self.eps + variance)
            return normalized * self.weight + self.bias, None

        # compute sqrt of the feature dimension if it is the first run
        if self.dim_sqrt is None:
            n = torch.tensor(x.shape[2], dtype=torch.float)
            self.dim_sqrt = torch.sqrt(n).to(x.device)

        # Normalization: computes mean and variance(std)
        x_int = x / scaling_factor
        mean_int = round_ste.apply(x_int.mean(dim=2, keepdim=True))
        y_int = x_int - mean_int
        y_int_shifted = floor_ste.apply(y_int / 2 ** self.shift)
        var_int = torch.sum(y_int_shifted ** 2, dim=2, keepdim=True)

        # overflow handling in training time
        if self.training and var_int.max() >= 2 ** self.max_bit:
            # overflow is detected: enlarge the shift and recompute
            var_int = self.overflow_fallback(y_int)
            assert var_int.max() < 2 ** self.max_bit + 0.1, (
                "Error detected in overflow handling: "
                "`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
            )

        # To be replaced with integer-sqrt kernel that produces the same output
        std_int = floor_ste.apply(torch.sqrt(var_int)) * 2 ** self.shift
        factor = floor_ste.apply(2 ** 31 / std_int)
        y_int = floor_ste.apply(y_int * factor / 2)
        out_scale = self.dim_sqrt / 2 ** 30

        # scaling and shifting
        bias = self.bias.data.detach() / (self.weight.data.detach())
        bias_int = floor_ste.apply(bias / out_scale)

        y_int = y_int + bias_int
        out_scale = out_scale * self.weight
        return y_int * out_scale, out_scale
def get_percentile_min_max(input, lower_percentile, upper_percentile, output_tensor=False):
    """
    Calculate the percentile max and min values in a given tensor

    Args:
        input (:obj:`torch.Tensor`):
            The target tensor to calculate percentile max and min. Its length is taken from `input.shape[0]`
            (assumed to be a flattened 1-D tensor -- TODO confirm against callers).
        lower_percentile (:obj:`float`):
            If 0.1, means we return the value of the smallest 0.1% value in the tensor as percentile min. If 0,
            the returned percentile min is 0.
        upper_percentile (:obj:`float`):
            If 99.9, means we return the value of the largest 0.1% value in the tensor as percentile max.
        output_tensor (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, this function returns tensors, otherwise it returns values.

    Returns:
        :obj:`Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of `input`
    """
    input_length = input.shape[0]

    def _clamp_rank(rank):
        # `torch.kthvalue` requires 1 <= k <= N; rounding can yield 0 for short tensors or extreme
        # percentiles, so snap the rank into the valid range.
        return max(1, min(input_length, rank))

    lower_index = _clamp_rank(round(input_length * (1 - lower_percentile * 0.01)))
    upper_index = _clamp_rank(round(input_length * upper_percentile * 0.01))

    upper_bound = torch.kthvalue(input, k=upper_index).values

    if lower_percentile == 0:
        # Zero with the same dtype/device as `upper_bound`.
        lower_bound = upper_bound * 0
    else:
        # k-th smallest of `-input`, negated, is the k-th largest of `input`.
        lower_bound = -torch.kthvalue(-input, k=lower_index).values

    if not output_tensor:
        lower_bound = lower_bound.item()
        upper_bound = upper_bound.item()
    return lower_bound, upper_bound
def linear_quantize(input, scale, zero_point, inplace=False):
    """
    Quantize a floating-point tensor to integer values: ``round(input / scale + zero_point)``.

    Args:
        input (:obj:`torch.Tensor`):
            Single-precision input tensor to be quantized.
        scale (:obj:`torch.Tensor`):
            Scaling factor for quantization.
        zero_point (:obj:`torch.Tensor`):
            Shift for quantization.
        inplace (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to compute inplace or not. When True, `input` itself is modified and returned.

    Returns:
        :obj:`torch.Tensor`: Linearly quantized value of `input` according to `scale` and `zero_point`.
    """
    # Pick the view that lets `scale`/`zero_point` broadcast along the leading dimension:
    # 4-D for convolutional weights and activations, 2-D for linear weights, flat otherwise.
    rank = input.dim()
    if rank == 4:
        broadcast_shape = (-1, 1, 1, 1)
    elif rank == 2:
        broadcast_shape = (-1, 1)
    else:
        broadcast_shape = (-1,)
    scale = scale.view(*broadcast_shape)
    zero_point = zero_point.view(*broadcast_shape)

    # quantized = float / scale + zero_point
    inv_scale = 1.0 / scale
    if not inplace:
        return torch.round(inv_scale * input + zero_point)

    input.mul_(inv_scale).add_(zero_point).round_()
    return input
def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
    """
    Derive the symmetric-quantization scaling factor for a given saturation range.

    Args:
        num_bits (:obj:`int`):
            Quantization bitwidth; the range is mapped onto ``2 ** (num_bits - 1) - 1`` levels per side.
        saturation_min (:obj:`torch.Tensor`):
            Lower bound for quantization range.
        saturation_max (:obj:`torch.Tensor`):
            Upper bound for quantization range.
        per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to or not use channel-wise quantization.

    Returns:
        :obj:`torch.Tensor`: Scaling factor that linearly quantizes the given range between `saturation_min` and
        `saturation_max`.
    """
    # No gradient is needed for the scale computation; torch.no_grad() enforces that.
    with torch.no_grad():
        max_level = 2 ** (num_bits - 1) - 1
        abs_min, abs_max = saturation_min.abs(), saturation_max.abs()

        if per_channel:
            # Channel-wise largest magnitude of the two bounds.
            magnitude = torch.stack([abs_min, abs_max], dim=1).max(dim=1).values
        else:
            magnitude = max(abs_min, abs_max)

        # The floor of 1e-8 keeps the scale strictly positive for an all-zero range.
        scale = torch.clamp(magnitude, min=1e-8) / max_level

    return scale
class SymmetricQuantFunction(Function):
    """
    Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
    """

    @staticmethod
    def forward(ctx, x, k, percentile_mode, scale):
        """
        Args:
            x (:obj:`torch.Tensor`):
                Floating point tensor to be quantized.
            k (:obj:`int`):
                Quantization bitwidth.
            percentile_mode (:obj:`bool`):
                Whether or not to use percentile calibration. Not read by this implementation.
            scale (:obj:`torch.Tensor`):
                Pre-calculated scaling factor for `x`. Note that the current implementation of SymmetricQuantFunction
                requires pre-calculated scaling factor.

        Returns:
            :obj:`torch.Tensor`: Symmetric-quantized value of `input`.
        """
        # Symmetric quantization: the zero point is always 0.
        zero_point = torch.tensor(0.0).to(scale.device)

        n = 2 ** (k - 1) - 1
        new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
        # NOTE(review): the upper clamp bound is n - 1 rather than n; kept unchanged -- confirm intent.
        new_quant_x = torch.clamp(new_quant_x, -n, n - 1)

        # Kept for the straight-through gradient in `backward`.
        ctx.scale = scale
        return new_quant_x

    @staticmethod
    def backward(ctx, grad_output):
        """
        Straight-through gradient: the incoming gradient divided by the scaling factor used in `forward`.

        Returns one entry per `forward` input `(x, k, percentile_mode, scale)`; only `x` receives a gradient.
        """
        scale = ctx.scale
        # Reshape `scale` so that it broadcasts the same way it did in `linear_quantize`.
        if len(grad_output.shape) == 4:
            scale = scale.view(-1, 1, 1, 1)
        # reshape scale for linear weights
        elif len(grad_output.shape) == 2:
            scale = scale.view(-1, 1)
        else:
            scale = scale.view(-1)

        return grad_output.clone() / scale, None, None, None
class floor_ste(Function):
    """
    `torch.floor()` with a Straight-through Estimator (STE) gradient: the forward pass floors its input,
    the backward pass hands the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        return x.floor()

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient; a copy is returned rather than the incoming tensor itself.
        return torch.clone(grad_output)
class round_ste(Function):
    """
    `torch.round()` with a Straight-through Estimator (STE) gradient: the forward pass rounds its input,
    the backward pass hands the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient; a copy is returned rather than the incoming tensor itself.
        return torch.clone(grad_output)
def batch_frexp(inputs, max_bit=31):
    """
    Decompose the scaling factor into mantissa and twos exponent.

    Args:
        inputs (:obj:`torch.Tensor`):
            Target scaling factor to decompose.
        max_bit (:obj:`int`, `optional`, defaults to 31):
            Number of bits the mantissa is shifted left by before it is rounded to an integer.

    Returns:
        :obj:`Tuple(torch.Tensor, torch.Tensor)`: The integer mantissa ``round(m * 2 ** max_bit)`` and the exponent
        ``max_bit - e``, where ``inputs == m * 2 ** e`` as given by `np.frexp`. Both have the shape and device of
        `inputs`.
    """
    shape_of_input = inputs.size()
    device = inputs.device

    # Flatten to 1-D; `reshape` also accepts non-contiguous inputs, and `detach` lets tensors that
    # track gradients be converted to NumPy.
    flat_inputs = inputs.detach().reshape(-1)
    mantissas, exponents = np.frexp(flat_inputs.cpu().numpy())

    shift = 2 ** max_bit
    unit = decimal.Decimal("1")
    # `float(m)` makes the product a Python float, which `Decimal` accepts for every NumPy float
    # dtype (it rejects `np.float32` scalars). ROUND_HALF_UP resolves ties away from zero.
    shifted_mantissas = [
        int(decimal.Decimal(float(m) * shift).quantize(unit, rounding=decimal.ROUND_HALF_UP)) for m in mantissas
    ]

    # Explicit int64: values can reach 2 ** max_bit, which does not fit a 32-bit default integer.
    output_m = np.array(shifted_mantissas, dtype=np.int64)
    output_e = float(max_bit) - exponents

    return (
        torch.from_numpy(output_m).to(device).view(shape_of_input),
        torch.from_numpy(output_e).to(device).view(shape_of_input),
    )
class FixedPointMul(Function):
    """
    Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.

    Args:
        pre_act (:obj:`torch.Tensor`):
            Input tensor.
        pre_act_scaling_factor (:obj:`torch.Tensor`):
            Scaling factor of the input tensor `pre_act`.
        bit_num (:obj:`int`):
            Quantization bitwidth.
        z_scaling_factor (:obj:`torch.Tensor`):
            Scaling factor of the output tensor.
        identity (:obj:`torch.Tensor`, `optional`, defaults to :obj:`None`):
            Identity tensor, if exists.
        identity_scaling_factor (:obj:`torch.Tensor`, `optional`, defaults to :obj:`None`):
            Scaling factor of the identity tensor `identity`, if exists.

    Returns:
        :obj:`torch.Tensor`: Output tensor(`pre_act` if `identity` is not given, otherwise the addition of `pre_act`
        and `identity`), whose scale is rescaled to `z_scaling_factor`.
    """

    @staticmethod
    def forward(
        ctx,
        pre_act,
        pre_act_scaling_factor,
        bit_num,
        z_scaling_factor,
        identity=None,
        identity_scaling_factor=None,
    ):
        # Scaling factors that are not already 3-D are viewed as (1, 1, -1) so they broadcast
        # over the last dimension of the activations.
        if len(pre_act_scaling_factor.shape) == 3:
            reshape = lambda x: x  # noqa: E731
        else:
            reshape = lambda x: x.view(1, 1, -1)  # noqa: E731

        # `backward` only checks whether an identity branch was present.
        ctx.identity = identity

        n = 2 ** (bit_num - 1) - 1

        with torch.no_grad():
            pre_act_scaling_factor = reshape(pre_act_scaling_factor)
            if identity is not None:
                identity_scaling_factor = reshape(identity_scaling_factor)

            ctx.z_scaling_factor = z_scaling_factor

            # Requantize `pre_act`: the scale ratio is split into an integer mantissa `m` and an
            # exponent `e`, and applied as (z_int * m) / 2 ** e.
            z_int = torch.round(pre_act / pre_act_scaling_factor)
            _A = pre_act_scaling_factor.type(torch.double)
            _B = (z_scaling_factor.type(torch.float)).type(torch.double)
            new_scale = _A / _B
            new_scale = reshape(new_scale)

            m, e = batch_frexp(new_scale)

            output = z_int.type(torch.double) * m.type(torch.double)
            output = torch.round(output / (2.0 ** e))

            if identity is not None:
                # needs addition of identity activation, requantized the same way with its own scale
                wx_int = torch.round(identity / identity_scaling_factor)

                _A = identity_scaling_factor.type(torch.double)
                _B = (z_scaling_factor.type(torch.float)).type(torch.double)
                new_scale = _A / _B
                new_scale = reshape(new_scale)

                m1, e1 = batch_frexp(new_scale)
                output1 = wx_int.type(torch.double) * m1.type(torch.double)
                output1 = torch.round(output1 / (2.0 ** e1))

                output = output1 + output

            # Saturate to the signed `bit_num`-bit range [-(n + 1), n].
            return torch.clamp(output.type(torch.float), -n - 1, n)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Straight-through gradient: the incoming gradient divided by the output scaling factor.

        Returns one entry per `forward` input, in order: `pre_act`, `pre_act_scaling_factor`, `bit_num`,
        `z_scaling_factor`, `identity`, `identity_scaling_factor`. Only `pre_act` and (when given) `identity`
        receive a gradient.
        """
        identity_grad = None
        if ctx.identity is not None:
            identity_grad = grad_output.clone() / ctx.z_scaling_factor

        # `identity_grad` must sit at index 4 to line up with the `identity` argument of `forward`.
        return grad_output.clone() / ctx.z_scaling_factor, None, None, None, identity_grad, None