Add ONNX support for Blenderbot and BlenderbotSmall (#15875)
* Add ONNX support for Blenderbot * Add BlenderbotSmall ONNX configuration * Update serialization table
This commit is contained in:
@@ -48,6 +48,8 @@ Ready-made configurations include the following architectures:
|
|||||||
- ALBERT
|
- ALBERT
|
||||||
- BART
|
- BART
|
||||||
- BERT
|
- BERT
|
||||||
|
- Blenderbot
|
||||||
|
- BlenderbotSmall
|
||||||
- CamemBERT
|
- CamemBERT
|
||||||
- Data2VecText
|
- Data2VecText
|
||||||
- DistilBERT
|
- DistilBERT
|
||||||
|
|||||||
@@ -22,7 +22,11 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokeniz
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig"],
|
"configuration_blenderbot": [
|
||||||
|
"BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||||
|
"BlenderbotConfig",
|
||||||
|
"BlenderbotOnnxConfig",
|
||||||
|
],
|
||||||
"tokenization_blenderbot": ["BlenderbotTokenizer"],
|
"tokenization_blenderbot": ["BlenderbotTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,7 +60,11 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
|
from .configuration_blenderbot import (
|
||||||
|
BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
BlenderbotConfig,
|
||||||
|
BlenderbotOnnxConfig,
|
||||||
|
)
|
||||||
from .tokenization_blenderbot import BlenderbotTokenizer
|
from .tokenization_blenderbot import BlenderbotTokenizer
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
|
|||||||
@@ -14,7 +14,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Blenderbot model configuration"""
|
""" Blenderbot model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
from ... import PreTrainedTokenizer
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...file_utils import TensorType, is_torch_available
|
||||||
|
from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
|
||||||
|
from ...onnx.utils import compute_effective_axis_dimension
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -164,3 +171,229 @@ class BlenderbotConfig(PretrainedConfig):
|
|||||||
forced_eos_token_id=forced_eos_token_id,
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlenderbotOnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task in ["default", "seq2seq-lm"]:
|
||||||
|
common_inputs = OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.use_past:
|
||||||
|
common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||||
|
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||||
|
else:
|
||||||
|
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
if self.use_past:
|
||||||
|
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||||
|
elif self.task == "causal-lm":
|
||||||
|
common_inputs = OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.use_past:
|
||||||
|
_, num_decoder_layers = self.num_layers
|
||||||
|
for i in range(num_decoder_layers):
|
||||||
|
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||||
|
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||||
|
else:
|
||||||
|
common_inputs = OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
|
||||||
|
("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.outputs
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task in ["default", "seq2seq-lm"]:
|
||||||
|
common_outputs = super().outputs
|
||||||
|
else:
|
||||||
|
common_outputs = super(OnnxConfigWithPast, self).outputs
|
||||||
|
if self.use_past:
|
||||||
|
num_encoder_layers, _ = self.num_layers
|
||||||
|
for i in range(num_encoder_layers):
|
||||||
|
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||||
|
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||||
|
return common_outputs
|
||||||
|
|
||||||
|
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size, seq_length, is_pair, framework
|
||||||
|
)
|
||||||
|
# Generate decoder inputs
|
||||||
|
decoder_seq_length = seq_length if not self.use_past else 1
|
||||||
|
decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
||||||
|
)
|
||||||
|
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||||
|
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||||
|
else:
|
||||||
|
import torch
|
||||||
|
batch, encoder_seq_length = common_inputs["input_ids"].shape
|
||||||
|
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
||||||
|
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
|
||||||
|
encoder_shape = (
|
||||||
|
batch,
|
||||||
|
num_encoder_attention_heads,
|
||||||
|
encoder_seq_length,
|
||||||
|
self._config.hidden_size // num_encoder_attention_heads,
|
||||||
|
)
|
||||||
|
decoder_past_length = decoder_seq_length
|
||||||
|
decoder_shape = (
|
||||||
|
batch,
|
||||||
|
num_decoder_attention_heads,
|
||||||
|
decoder_past_length,
|
||||||
|
self._config.hidden_size // num_decoder_attention_heads,
|
||||||
|
)
|
||||||
|
common_inputs["decoder_attention_mask"] = torch.cat(
|
||||||
|
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
|
||||||
|
)
|
||||||
|
common_inputs["past_key_values"] = []
|
||||||
|
_, num_decoder_layers = self.num_layers
|
||||||
|
|
||||||
|
for _ in range(num_decoder_layers):
|
||||||
|
common_inputs["past_key_values"].append(
|
||||||
|
(
|
||||||
|
torch.zeros(decoder_shape),
|
||||||
|
torch.zeros(decoder_shape),
|
||||||
|
torch.zeros(encoder_shape),
|
||||||
|
torch.zeros(encoder_shape),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
def _generate_dummy_inputs_for_causal_lm(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size, seq_length, is_pair, framework
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||||
|
else:
|
||||||
|
import torch
|
||||||
|
batch, seqlen = common_inputs["input_ids"].shape
|
||||||
|
past_key_values_length = seqlen
|
||||||
|
_, num_decoder_layers = self.num_layers
|
||||||
|
num_encoder_attention_heads, _ = self.num_attention_heads
|
||||||
|
past_shape = (
|
||||||
|
batch,
|
||||||
|
num_encoder_attention_heads,
|
||||||
|
past_key_values_length,
|
||||||
|
self._config.hidden_size // num_encoder_attention_heads,
|
||||||
|
)
|
||||||
|
common_inputs["attention_mask"] = torch.cat(
|
||||||
|
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
|
||||||
|
)
|
||||||
|
common_inputs["past_key_values"] = [
|
||||||
|
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_decoder_layers)
|
||||||
|
]
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
|
||||||
|
def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
# Copied from OnnxConfig.generate_dummy_inputs
|
||||||
|
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||||
|
batch_size = compute_effective_axis_dimension(
|
||||||
|
batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
||||||
|
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
||||||
|
seq_length = compute_effective_axis_dimension(
|
||||||
|
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate dummy inputs according to compute batch and sequence
|
||||||
|
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||||
|
common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.generate_dummy_inputs
|
||||||
|
def generate_dummy_inputs(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
if self.task in ["default", "seq2seq-lm"]:
|
||||||
|
common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
|
||||||
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.task == "causal-lm":
|
||||||
|
common_inputs = self._generate_dummy_inputs_for_causal_lm(
|
||||||
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._flatten_past_key_values_
|
||||||
|
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
|
||||||
|
if self.task in ["default", "seq2seq-lm"]:
|
||||||
|
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
|
||||||
|
else:
|
||||||
|
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
|
||||||
|
flattened_output, name, idx, t
|
||||||
|
)
|
||||||
|
|
||||||
|
def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
|
||||||
|
if direction not in ["inputs", "outputs"]:
|
||||||
|
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
|
||||||
|
|
||||||
|
name = "past_key_values" if direction == "inputs" else "present"
|
||||||
|
_, num_decoder_layers = self.num_layers
|
||||||
|
|
||||||
|
encoder_sequence = "past_encoder_sequence"
|
||||||
|
decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence"
|
||||||
|
|
||||||
|
for i in range(num_decoder_layers):
|
||||||
|
inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch", 2: decoder_sequence}
|
||||||
|
inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence}
|
||||||
|
inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence}
|
||||||
|
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence}
|
||||||
|
|||||||
@@ -21,7 +21,11 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokeniz
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_blenderbot_small": ["BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotSmallConfig"],
|
"configuration_blenderbot_small": [
|
||||||
|
"BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||||
|
"BlenderbotSmallConfig",
|
||||||
|
"BlenderbotSmallOnnxConfig",
|
||||||
|
],
|
||||||
"tokenization_blenderbot_small": ["BlenderbotSmallTokenizer"],
|
"tokenization_blenderbot_small": ["BlenderbotSmallTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +56,11 @@ if is_flax_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_blenderbot_small import BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotSmallConfig
|
from .configuration_blenderbot_small import (
|
||||||
|
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
BlenderbotSmallConfig,
|
||||||
|
BlenderbotSmallOnnxConfig,
|
||||||
|
)
|
||||||
from .tokenization_blenderbot_small import BlenderbotSmallTokenizer
|
from .tokenization_blenderbot_small import BlenderbotSmallTokenizer
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
|
|||||||
@@ -14,7 +14,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" BlenderbotSmall model configuration"""
|
""" BlenderbotSmall model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
from ... import PreTrainedTokenizer
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...file_utils import TensorType, is_torch_available
|
||||||
|
from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
|
||||||
|
from ...onnx.utils import compute_effective_axis_dimension
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -162,3 +169,226 @@ class BlenderbotSmallConfig(PretrainedConfig):
|
|||||||
forced_eos_token_id=forced_eos_token_id,
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig
|
||||||
|
class BlenderbotSmallOnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task in ["default", "seq2seq-lm"]:
|
||||||
|
common_inputs = OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||||
|
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||||
|
else:
|
||||||
|
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||||
|
elif self.task == "causal-lm":
|
||||||
|
# TODO: figure this case out.
|
||||||
|
common_inputs = OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.use_past:
|
||||||
|
num_encoder_layers, _ = self.num_layers
|
||||||
|
for i in range(num_encoder_layers):
|
||||||
|
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||||
|
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||||
|
else:
|
||||||
|
common_inputs = OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
|
||||||
|
("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task in ["default", "seq2seq-lm"]:
|
||||||
|
common_outputs = super().outputs
|
||||||
|
else:
|
||||||
|
common_outputs = super(OnnxConfigWithPast, self).outputs
|
||||||
|
if self.use_past:
|
||||||
|
num_encoder_layers, _ = self.num_layers
|
||||||
|
for i in range(num_encoder_layers):
|
||||||
|
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||||
|
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||||
|
return common_outputs
|
||||||
|
|
||||||
|
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size, seq_length, is_pair, framework
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate decoder inputs
|
||||||
|
decoder_seq_length = seq_length if not self.use_past else 1
|
||||||
|
decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
||||||
|
)
|
||||||
|
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||||
|
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||||
|
else:
|
||||||
|
import torch
|
||||||
|
batch, encoder_seq_length = common_inputs["input_ids"].shape
|
||||||
|
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
||||||
|
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
|
||||||
|
encoder_shape = (
|
||||||
|
batch,
|
||||||
|
num_encoder_attention_heads,
|
||||||
|
encoder_seq_length,
|
||||||
|
self._config.hidden_size // num_encoder_attention_heads,
|
||||||
|
)
|
||||||
|
decoder_past_length = decoder_seq_length + 3
|
||||||
|
decoder_shape = (
|
||||||
|
batch,
|
||||||
|
num_decoder_attention_heads,
|
||||||
|
decoder_past_length,
|
||||||
|
self._config.hidden_size // num_decoder_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
common_inputs["decoder_attention_mask"] = torch.cat(
|
||||||
|
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
common_inputs["past_key_values"] = []
|
||||||
|
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
||||||
|
num_encoder_layers, num_decoder_layers = self.num_layers
|
||||||
|
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
||||||
|
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
||||||
|
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
||||||
|
|
||||||
|
for _ in range(min_num_layers):
|
||||||
|
common_inputs["past_key_values"].append(
|
||||||
|
(
|
||||||
|
torch.zeros(decoder_shape),
|
||||||
|
torch.zeros(decoder_shape),
|
||||||
|
torch.zeros(encoder_shape),
|
||||||
|
torch.zeros(encoder_shape),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# TODO: test this.
|
||||||
|
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
|
||||||
|
for _ in range(min_num_layers, max_num_layers):
|
||||||
|
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
def _generate_dummy_inputs_for_causal_lm(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size, seq_length, is_pair, framework
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||||
|
else:
|
||||||
|
import torch
|
||||||
|
batch, seqlen = common_inputs["input_ids"].shape
|
||||||
|
# Not using the same length for past_key_values
|
||||||
|
past_key_values_length = seqlen + 2
|
||||||
|
num_encoder_layers, _ = self.num_layers
|
||||||
|
num_encoder_attention_heads, _ = self.num_attention_heads
|
||||||
|
past_shape = (
|
||||||
|
batch,
|
||||||
|
num_encoder_attention_heads,
|
||||||
|
past_key_values_length,
|
||||||
|
self._config.hidden_size // num_encoder_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
common_inputs["attention_mask"] = torch.cat(
|
||||||
|
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
|
||||||
|
)
|
||||||
|
common_inputs["past_key_values"] = [
|
||||||
|
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
|
||||||
|
]
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
# Copied from OnnxConfig.generate_dummy_inputs
|
||||||
|
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||||
|
batch_size = compute_effective_axis_dimension(
|
||||||
|
batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
||||||
|
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
||||||
|
seq_length = compute_effective_axis_dimension(
|
||||||
|
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate dummy inputs according to compute batch and sequence
|
||||||
|
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||||
|
common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
def generate_dummy_inputs(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
if self.task in ["default", "seq2seq-lm"]:
|
||||||
|
common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
|
||||||
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.task == "causal-lm":
|
||||||
|
common_inputs = self._generate_dummy_inputs_for_causal_lm(
|
||||||
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
|
||||||
|
if self.task in ["default", "seq2seq-lm"]:
|
||||||
|
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
|
||||||
|
else:
|
||||||
|
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
|
||||||
|
flattened_output, name, idx, t
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_avail
|
|||||||
from ..models.albert import AlbertOnnxConfig
|
from ..models.albert import AlbertOnnxConfig
|
||||||
from ..models.bart import BartOnnxConfig
|
from ..models.bart import BartOnnxConfig
|
||||||
from ..models.bert import BertOnnxConfig
|
from ..models.bert import BertOnnxConfig
|
||||||
|
from ..models.blenderbot import BlenderbotOnnxConfig
|
||||||
|
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
||||||
from ..models.camembert import CamembertOnnxConfig
|
from ..models.camembert import CamembertOnnxConfig
|
||||||
from ..models.distilbert import DistilBertOnnxConfig
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
from ..models.electra import ElectraOnnxConfig
|
from ..models.electra import ElectraOnnxConfig
|
||||||
@@ -268,6 +270,24 @@ class FeaturesManager:
|
|||||||
onnx_config_cls=ElectraOnnxConfig,
|
onnx_config_cls=ElectraOnnxConfig,
|
||||||
),
|
),
|
||||||
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
|
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
|
||||||
|
"blenderbot": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"default-with-past",
|
||||||
|
"causal-lm",
|
||||||
|
"causal-lm-with-past",
|
||||||
|
"seq2seq-lm",
|
||||||
|
"seq2seq-lm-with-past",
|
||||||
|
onnx_config_cls=BlenderbotOnnxConfig,
|
||||||
|
),
|
||||||
|
"blenderbot-small": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"default-with-past",
|
||||||
|
"causal-lm",
|
||||||
|
"causal-lm-with-past",
|
||||||
|
"seq2seq-lm",
|
||||||
|
"seq2seq-lm-with-past",
|
||||||
|
onnx_config_cls=BlenderbotSmallOnnxConfig,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
|
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
|
||||||
|
|||||||
@@ -194,6 +194,8 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
|||||||
("t5", "t5-small"),
|
("t5", "t5-small"),
|
||||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||||
("m2m-100", "facebook/m2m100_418M"),
|
("m2m-100", "facebook/m2m100_418M"),
|
||||||
|
("blenderbot-small", "facebook/blenderbot_small-90M"),
|
||||||
|
("blenderbot", "facebook/blenderbot-400M-distill"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
||||||
|
|||||||
Reference in New Issue
Block a user