M2M100 support for ONNX export (#15193)
* Add M2M100 support for ONNX export * Delete useless imports * Add M2M100 to tests * Fix protobuf issue
This commit is contained in:
@@ -55,6 +55,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- GPT Neo
|
- GPT Neo
|
||||||
- I-BERT
|
- I-BERT
|
||||||
- LayoutLM
|
- LayoutLM
|
||||||
|
- M2M100
|
||||||
- Marian
|
- Marian
|
||||||
- mBART
|
- mBART
|
||||||
- OpenAI GPT-2
|
- OpenAI GPT-2
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_availab
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
|
"configuration_m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config", "M2M100OnnxConfig"],
|
||||||
"tokenization_m2m_100": ["M2M100Tokenizer"],
|
"tokenization_m2m_100": ["M2M100Tokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
|
from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config, M2M100OnnxConfig
|
||||||
from .tokenization_m2m_100 import M2M100Tokenizer
|
from .tokenization_m2m_100 import M2M100Tokenizer
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
|||||||
@@ -13,8 +13,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" M2M100 model configuration"""
|
""" M2M100 model configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
from ... import PreTrainedTokenizer
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...file_utils import TensorType, is_torch_available
|
||||||
|
from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
|
||||||
|
from ...onnx.utils import compute_effective_axis_dimension
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -153,3 +159,126 @@ class M2M100Config(PretrainedConfig):
|
|||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class M2M100OnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
common_inputs = OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||||
|
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||||
|
else:
|
||||||
|
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
# Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
|
||||||
|
# A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
|
||||||
|
# answering are not supported for M2M100, but this name is preserved to be able to check that the copy matches what
|
||||||
|
# was done for BART so that it can be updated if need be.
|
||||||
|
def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
# Copied from OnnxConfig.generate_dummy_inputs
|
||||||
|
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||||
|
batch_size = compute_effective_axis_dimension(
|
||||||
|
batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
||||||
|
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
||||||
|
seq_length = compute_effective_axis_dimension(
|
||||||
|
seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate dummy inputs according to compute batch and sequence
|
||||||
|
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||||
|
common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
|
||||||
|
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size, seq_length, is_pair, framework
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate decoder inputs
|
||||||
|
decoder_seq_length = seq_length if not self.use_past else 1
|
||||||
|
decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
||||||
|
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
||||||
|
)
|
||||||
|
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||||
|
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||||
|
else:
|
||||||
|
import torch
|
||||||
|
batch, encoder_seq_length = common_inputs["input_ids"].shape
|
||||||
|
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
||||||
|
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
|
||||||
|
encoder_shape = (
|
||||||
|
batch,
|
||||||
|
num_encoder_attention_heads,
|
||||||
|
encoder_seq_length,
|
||||||
|
self._config.hidden_size // num_encoder_attention_heads,
|
||||||
|
)
|
||||||
|
decoder_past_length = decoder_seq_length + 3
|
||||||
|
decoder_shape = (
|
||||||
|
batch,
|
||||||
|
num_decoder_attention_heads,
|
||||||
|
decoder_past_length,
|
||||||
|
self._config.hidden_size // num_decoder_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
common_inputs["decoder_attention_mask"] = torch.cat(
|
||||||
|
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
common_inputs["past_key_values"] = []
|
||||||
|
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
||||||
|
num_encoder_layers, num_decoder_layers = self.num_layers
|
||||||
|
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
||||||
|
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
||||||
|
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
||||||
|
|
||||||
|
for _ in range(min_num_layers):
|
||||||
|
common_inputs["past_key_values"].append(
|
||||||
|
(
|
||||||
|
torch.zeros(decoder_shape),
|
||||||
|
torch.zeros(decoder_shape),
|
||||||
|
torch.zeros(encoder_shape),
|
||||||
|
torch.zeros(encoder_shape),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# TODO: test this.
|
||||||
|
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
|
||||||
|
for _ in range(min_num_layers, max_num_layers):
|
||||||
|
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
|
||||||
|
|||||||
@@ -117,21 +117,34 @@ def export_pytorch(
|
|||||||
|
|
||||||
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
|
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
|
||||||
# so we check the torch version for backwards compatibility
|
# so we check the torch version for backwards compatibility
|
||||||
if parse(torch.__version__) <= parse("1.10.99"):
|
if parse(torch.__version__) < parse("1.10"):
|
||||||
# export can work with named args but the dict containing named args
|
# export can work with named args but the dict containing named args
|
||||||
# has to be the last element of the args tuple.
|
# has to be the last element of the args tuple.
|
||||||
onnx_export(
|
try:
|
||||||
model,
|
onnx_export(
|
||||||
(model_inputs,),
|
model,
|
||||||
f=output.as_posix(),
|
(model_inputs,),
|
||||||
input_names=list(config.inputs.keys()),
|
f=output.as_posix(),
|
||||||
output_names=onnx_outputs,
|
input_names=list(config.inputs.keys()),
|
||||||
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
|
output_names=onnx_outputs,
|
||||||
do_constant_folding=True,
|
dynamic_axes={
|
||||||
use_external_data_format=config.use_external_data_format(model.num_parameters()),
|
name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())
|
||||||
enable_onnx_checker=True,
|
},
|
||||||
opset_version=opset,
|
do_constant_folding=True,
|
||||||
)
|
use_external_data_format=config.use_external_data_format(model.num_parameters()),
|
||||||
|
enable_onnx_checker=True,
|
||||||
|
opset_version=opset,
|
||||||
|
)
|
||||||
|
except RuntimeError as err:
|
||||||
|
message = str(err)
|
||||||
|
if (
|
||||||
|
message
|
||||||
|
== "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without setting use_external_data_format parameter."
|
||||||
|
):
|
||||||
|
message = "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without setting use_external_data_format parameter or try with torch 1.10+."
|
||||||
|
raise RuntimeError(message)
|
||||||
|
else:
|
||||||
|
raise err
|
||||||
else:
|
else:
|
||||||
onnx_export(
|
onnx_export(
|
||||||
model,
|
model,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from ..models.gpt2 import GPT2OnnxConfig
|
|||||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||||
from ..models.ibert import IBertOnnxConfig
|
from ..models.ibert import IBertOnnxConfig
|
||||||
from ..models.layoutlm import LayoutLMOnnxConfig
|
from ..models.layoutlm import LayoutLMOnnxConfig
|
||||||
|
from ..models.m2m_100 import M2M100OnnxConfig
|
||||||
from ..models.marian import MarianOnnxConfig
|
from ..models.marian import MarianOnnxConfig
|
||||||
from ..models.mbart import MBartOnnxConfig
|
from ..models.mbart import MBartOnnxConfig
|
||||||
from ..models.roberta import RobertaOnnxConfig
|
from ..models.roberta import RobertaOnnxConfig
|
||||||
@@ -184,6 +185,9 @@ class FeaturesManager:
|
|||||||
"causal-lm-with-past",
|
"causal-lm-with-past",
|
||||||
onnx_config_cls=MarianOnnxConfig,
|
onnx_config_cls=MarianOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"m2m-100": supported_features_mapping(
|
||||||
|
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
|
||||||
|
),
|
||||||
"roberta": supported_features_mapping(
|
"roberta": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -190,6 +190,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
|||||||
("mbart", "sshleifer/tiny-mbart"),
|
("mbart", "sshleifer/tiny-mbart"),
|
||||||
("t5", "t5-small"),
|
("t5", "t5-small"),
|
||||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||||
|
("m2m-100", "facebook/m2m100_418M"),
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
||||||
|
|||||||
Reference in New Issue
Block a user