add ONNX support for BLOOM (#17961)
* add onnx support for BLOOM
* use TYPE_CHECKING for type annotations
* fix past_shape for bloom (different from gpt2)
* use logical_or instead of `+` for onnx support
* bigger `atol_for_validation` for larger bloom models
* copied -> taken because it's no longer an exact copy
* remove "copied from" comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- BigBird-Pegasus
|
- BigBird-Pegasus
|
||||||
- Blenderbot
|
- Blenderbot
|
||||||
- BlenderbotSmall
|
- BlenderbotSmall
|
||||||
|
- BLOOM
|
||||||
- CamemBERT
|
- CamemBERT
|
||||||
- CodeGen
|
- CodeGen
|
||||||
- ConvBERT
|
- ConvBERT
|
||||||
|
|||||||
@@ -22,10 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_bloom": [
|
"configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"],
|
||||||
"BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
|
||||||
"BloomConfig",
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
if not is_tokenizers_available():
|
if not is_tokenizers_available():
|
||||||
@@ -51,7 +48,7 @@ else:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig
|
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_tokenizers_available():
|
if not is_tokenizers_available():
|
||||||
|
|||||||
@@ -13,7 +13,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Bloom configuration"""
|
""" Bloom configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
|
||||||
|
|
||||||
|
from transformers import is_torch_available
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedTokenizer, TensorType
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfigWithPast, PatchingSpec
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -153,3 +163,88 @@ class BloomConfig(PretrainedConfig):
|
|||||||
self.slow_but_exact = slow_but_exact
|
self.slow_but_exact = slow_but_exact
|
||||||
|
|
||||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class BloomOnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for BLOOM, optionally exposing past key/values as inputs."""

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
        use_past: bool = False,
    ):
        """Forward all arguments to the parent class, then make sure a pad token id is set.

        Args:
            config: Model configuration to export.
            task: Name of the export task.
            patching_specs: Optional patching specs, passed through to the parent class.
            use_past: Whether past key/values are part of the exported inputs.
        """
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        # A missing, `None` or otherwise falsy pad token id is replaced by 0.
        # TODO: how to do that better?
        pad_token_id = getattr(self._config, "pad_token_id", None)
        if not pad_token_id:
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Dynamic-axis names of each model input, in the order the inputs are declared."""
        axes = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            # The past key/value entries are inserted before the attention mask.
            self.fill_with_past_key_values_(axes, direction="inputs")
            mask_sequence_axis = "past_sequence + sequence"
        else:
            mask_sequence_axis = "sequence"
        axes["attention_mask"] = {0: "batch", 1: mask_sequence_axis}

        return axes

    @property
    def num_layers(self) -> int:
        """Value of `n_layer` on the wrapped config."""
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        """Value of `n_head` on the wrapped config."""
        return self._config.n_head

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance reported for validation."""
        return 1e-3

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizer",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs, adding zero-filled past key/values when `use_past` is set.

        Args:
            tokenizer: Tokenizer forwarded to the parent implementation.
            batch_size: Forwarded to the parent implementation.
            seq_length: Forwarded to the parent implementation.
            is_pair: Forwarded to the parent implementation.
            framework: Forwarded to the parent implementation.

        Returns:
            An ordered mapping with `input_ids`, then `past_key_values` (only with
            `use_past`), then `attention_mask`.

        Raises:
            ValueError: If `use_past` is set and PyTorch is not available.
        """
        # Deliberately resolve the method on the class that follows OnnxConfigWithPast in the MRO.
        base_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # Insertion order is kept aligned with the order of the arguments of forward().
        dummy_inputs = OrderedDict({"input_ids": base_inputs["input_ids"]})

        if not self.use_past:
            dummy_inputs["attention_mask"] = base_inputs["attention_mask"]
            return dummy_inputs

        if not is_torch_available():
            raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
        import torch

        batch, seqlen = base_inputs["input_ids"].shape
        # The past length is intentionally different from the current sequence length.
        past_key_values_length = seqlen + 2
        head_dim = self._config.hidden_size // self.num_attention_heads
        past_shape = (batch, past_key_values_length, self.num_attention_heads, head_dim)
        dummy_inputs["past_key_values"] = [
            (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
        ]

        # Extend the mask with ones so that it also spans the past positions.
        attention_mask = base_inputs["attention_mask"]
        past_mask = torch.ones(batch, past_key_values_length, dtype=attention_mask.dtype)
        dummy_inputs["attention_mask"] = torch.cat([attention_mask, past_mask], dim=1)

        return dummy_inputs

    @property
    def default_onnx_opset(self) -> int:
        """Default ONNX opset version for this configuration."""
        return 13
|
||||||
|
|||||||
@@ -78,17 +78,14 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
|
|||||||
|
|
||||||
|
|
||||||
def attention_mask_func(attention_scores, attention_mask, causal_mask):
|
def attention_mask_func(attention_scores, attention_mask, causal_mask):
|
||||||
if attention_mask.dtype == torch.bool:
|
attention_mask_bool = ~attention_mask.bool()
|
||||||
attention_mask_bool = ~attention_mask
|
|
||||||
else:
|
|
||||||
attention_mask_bool = (1 - attention_mask).bool()
|
|
||||||
|
|
||||||
query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
|
query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
|
||||||
padded_causal_mask = (
|
padded_causal_mask = torch.logical_or(
|
||||||
attention_mask_bool[:, None, key_length - query_length : key_length, None]
|
attention_mask_bool[:, None, key_length - query_length : key_length, None],
|
||||||
+ ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
|
~causal_mask[:, :, key_length - query_length : key_length, :key_length].bool(),
|
||||||
).bool()
|
)
|
||||||
padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
|
padded_causal_mask = torch.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length])
|
||||||
# Make use of floats
|
# Make use of floats
|
||||||
return (
|
return (
|
||||||
attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
|
attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
|
||||||
@@ -296,11 +293,8 @@ class BloomScaledSoftmax(nn.Module):
|
|||||||
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
|
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
|
||||||
|
|
||||||
mask = mask.to(input.device)
|
mask = mask.to(input.device)
|
||||||
causal_mask = (
|
seq_ids = torch.arange(max_positions, device=input.device)
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
|
causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
|
||||||
.view(1, 1, max_positions, max_positions)
|
|
||||||
.to(input.device)
|
|
||||||
)
|
|
||||||
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
|
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
|
||||||
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
|
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
|
||||||
|
|
||||||
|
|||||||
@@ -182,6 +182,15 @@ class FeaturesManager:
|
|||||||
"seq2seq-lm-with-past",
|
"seq2seq-lm-with-past",
|
||||||
onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
|
onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
|
||||||
),
|
),
|
||||||
|
"bloom": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"default-with-past",
|
||||||
|
"causal-lm",
|
||||||
|
"causal-lm-with-past",
|
||||||
|
"sequence-classification",
|
||||||
|
"token-classification",
|
||||||
|
onnx_config_cls="models.bloom.BloomOnnxConfig",
|
||||||
|
),
|
||||||
"camembert": supported_features_mapping(
|
"camembert": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -205,6 +205,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||||
|
("bloom", "bigscience/bloom-350m"),
|
||||||
("gpt2", "gpt2"),
|
("gpt2", "gpt2"),
|
||||||
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user