add ONNX support for BLOOM (#17961)
* add onnx support for BLOOM * use TYPE_CHECKING for type annotations * fix past_shape for bloom (different from gpt2) * use logical_or instead of `+` for onnx support * bigger `atol_for_validation` for larger bloom models * copied -> taken because it's no longer an exact copy * remove "copied from" comment Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
|
||||
- BigBird-Pegasus
|
||||
- Blenderbot
|
||||
- BlenderbotSmall
|
||||
- BLOOM
|
||||
- CamemBERT
|
||||
- CodeGen
|
||||
- ConvBERT
|
||||
|
||||
@@ -22,10 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_bloom": [
|
||||
"BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"BloomConfig",
|
||||
],
|
||||
"configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"],
|
||||
}
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
@@ -51,7 +48,7 @@ else:
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig
|
||||
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig
|
||||
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
|
||||
@@ -13,7 +13,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Bloom configuration"""
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, TensorType
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfigWithPast, PatchingSpec
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@@ -153,3 +163,88 @@ class BloomConfig(PretrainedConfig):
|
||||
self.slow_but_exact = slow_but_exact
|
||||
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
|
||||
class BloomOnnxConfig(OnnxConfigWithPast):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
task: str = "default",
|
||||
patching_specs: List[PatchingSpec] = None,
|
||||
use_past: bool = False,
|
||||
):
|
||||
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
|
||||
if not getattr(self._config, "pad_token_id", None):
|
||||
# TODO: how to do that better?
|
||||
self._config.pad_token_id = 0
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
|
||||
else:
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return self._config.n_layer
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self._config.n_head
|
||||
|
||||
@property
|
||||
def atol_for_validation(self) -> float:
|
||||
return 1e-3
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
is_pair: bool = False,
|
||||
framework: Optional["TensorType"] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||
)
|
||||
|
||||
# We need to order the input in the way they appears in the forward()
|
||||
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
|
||||
|
||||
# Need to add the past_keys
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
|
||||
batch, seqlen = common_inputs["input_ids"].shape
|
||||
# Not using the same length for past_key_values
|
||||
past_key_values_length = seqlen + 2
|
||||
past_shape = (
|
||||
batch,
|
||||
past_key_values_length,
|
||||
self.num_attention_heads,
|
||||
self._config.hidden_size // self.num_attention_heads,
|
||||
)
|
||||
ordered_inputs["past_key_values"] = [
|
||||
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
|
||||
]
|
||||
|
||||
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
|
||||
if self.use_past:
|
||||
mask_dtype = ordered_inputs["attention_mask"].dtype
|
||||
ordered_inputs["attention_mask"] = torch.cat(
|
||||
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@property
|
||||
def default_onnx_opset(self) -> int:
|
||||
return 13
|
||||
|
||||
@@ -78,17 +78,14 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
|
||||
|
||||
|
||||
def attention_mask_func(attention_scores, attention_mask, causal_mask):
|
||||
if attention_mask.dtype == torch.bool:
|
||||
attention_mask_bool = ~attention_mask
|
||||
else:
|
||||
attention_mask_bool = (1 - attention_mask).bool()
|
||||
attention_mask_bool = ~attention_mask.bool()
|
||||
|
||||
query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
|
||||
padded_causal_mask = (
|
||||
attention_mask_bool[:, None, key_length - query_length : key_length, None]
|
||||
+ ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
|
||||
).bool()
|
||||
padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
|
||||
padded_causal_mask = torch.logical_or(
|
||||
attention_mask_bool[:, None, key_length - query_length : key_length, None],
|
||||
~causal_mask[:, :, key_length - query_length : key_length, :key_length].bool(),
|
||||
)
|
||||
padded_causal_mask = torch.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length])
|
||||
# Make use of floats
|
||||
return (
|
||||
attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
|
||||
@@ -296,11 +293,8 @@ class BloomScaledSoftmax(nn.Module):
|
||||
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
|
||||
|
||||
mask = mask.to(input.device)
|
||||
causal_mask = (
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
|
||||
.view(1, 1, max_positions, max_positions)
|
||||
.to(input.device)
|
||||
)
|
||||
seq_ids = torch.arange(max_positions, device=input.device)
|
||||
causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
|
||||
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
|
||||
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
|
||||
|
||||
|
||||
@@ -182,6 +182,15 @@ class FeaturesManager:
|
||||
"seq2seq-lm-with-past",
|
||||
onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
|
||||
),
|
||||
"bloom": supported_features_mapping(
|
||||
"default",
|
||||
"default-with-past",
|
||||
"causal-lm",
|
||||
"causal-lm-with-past",
|
||||
"sequence-classification",
|
||||
"token-classification",
|
||||
onnx_config_cls="models.bloom.BloomOnnxConfig",
|
||||
),
|
||||
"camembert": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
|
||||
@@ -205,6 +205,7 @@ PYTORCH_EXPORT_MODELS = {
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||
("bloom", "bigscience/bloom-350m"),
|
||||
("gpt2", "gpt2"),
|
||||
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user