Add mt5 onnx config (#18394)
* update features
* MT5OnnxConfig added, with updated tests and docs
* fix imports
* fix onnx_config_cls for mt5

Co-authored-by: Thomas Chaigneau <thomas.deeptools.ai>
This commit is contained in:
@@ -79,6 +79,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- mBART
|
- mBART
|
||||||
- MobileBERT
|
- MobileBERT
|
||||||
- MobileViT
|
- MobileViT
|
||||||
|
- MT5
|
||||||
- OpenAI GPT-2
|
- OpenAI GPT-2
|
||||||
- Perceiver
|
- Perceiver
|
||||||
- PLBart
|
- PLBart
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ else:
|
|||||||
|
|
||||||
MT5TokenizerFast = T5TokenizerFast
|
MT5TokenizerFast = T5TokenizerFast
|
||||||
|
|
||||||
_import_structure = {"configuration_mt5": ["MT5Config"]}
|
_import_structure = {"configuration_mt5": ["MT5Config", "MT5OnnxConfig"]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
@@ -71,7 +71,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_mt5 import MT5Config
|
from .configuration_mt5 import MT5Config, MT5OnnxConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
|
|||||||
@@ -13,8 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" mT5 model configuration"""
|
""" mT5 model configuration"""
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxSeq2SeqConfigWithPast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -143,3 +145,29 @@ class MT5Config(PretrainedConfig):
|
|||||||
@property
|
@property
|
||||||
def num_hidden_layers(self):
|
def num_hidden_layers(self):
|
||||||
return self.num_layers
|
return self.num_layers
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig
|
||||||
|
class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
common_inputs = {
|
||||||
|
"input_ids": {0: "batch", 1: "encoder_sequence"},
|
||||||
|
"attention_mask": {0: "batch", 1: "encoder_sequence"},
|
||||||
|
}
|
||||||
|
if self.use_past:
|
||||||
|
common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
|
||||||
|
common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||||
|
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||||
|
else:
|
||||||
|
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||||
|
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_onnx_opset(self) -> int:
|
||||||
|
return 13
|
||||||
|
|||||||
@@ -383,6 +383,13 @@ class FeaturesManager:
|
|||||||
"image-classification",
|
"image-classification",
|
||||||
onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
|
onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
|
||||||
),
|
),
|
||||||
|
"mt5": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"default-with-past",
|
||||||
|
"seq2seq-lm",
|
||||||
|
"seq2seq-lm-with-past",
|
||||||
|
onnx_config_cls="models.mt5.MT5OnnxConfig",
|
||||||
|
),
|
||||||
"m2m-100": supported_features_mapping(
|
"m2m-100": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"default-with-past",
|
"default-with-past",
|
||||||
|
|||||||
@@ -224,6 +224,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
|||||||
("mbart", "sshleifer/tiny-mbart"),
|
("mbart", "sshleifer/tiny-mbart"),
|
||||||
("t5", "t5-small"),
|
("t5", "t5-small"),
|
||||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||||
|
("mt5", "google/mt5-base"),
|
||||||
("m2m-100", "facebook/m2m100_418M"),
|
("m2m-100", "facebook/m2m100_418M"),
|
||||||
("blenderbot-small", "facebook/blenderbot_small-90M"),
|
("blenderbot-small", "facebook/blenderbot_small-90M"),
|
||||||
("blenderbot", "facebook/blenderbot-400M-distill"),
|
("blenderbot", "facebook/blenderbot-400M-distill"),
|
||||||
|
|||||||
Reference in New Issue
Block a user