From 6b655cc63fd7e8ec120d9c38321a286d04767db0 Mon Sep 17 00:00:00 2001 From: lewtun Date: Thu, 23 Dec 2021 13:35:56 +0100 Subject: [PATCH] Add ONNX support for MarianMT models (#14586) * First commit to add MarianMT to ONNX * Now MarianModel.forward() automatically generates decoder_input_ids, like BartModel.forward() * Adjusted MarianOnnxConfig.inputs and outputs to work with seq2seq-lm feature * Style fix * Added support for other features for already supported models * Partial support for causal and seq2seq models * Partial support for causal and seq2seq models * Add default task for MarianMT ONNX * Remove automatic creation of decoder_input_ids * Extend inputs and outputs for MarianMT ONNX config * Add MarianMT to ONNX unit tests * Refactor * OnnxSeq2SeqConfigWithPast to support seq2seq models * Parameterized the onnx tests * Restored run_mlm.py * Restored run_mlm.py * [WIP] BART update * BART and MBART * Add past_key_values and fix dummy decoder inputs Using a sequence length of 1 in generate_dummy_outputs() produces large discrepancies, presumably due to some hidden optimisations. * Refactor MarianOnnxConfig to remove custom past_key_values logic * Fix quality * Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)" This reverts commit 0f4e39c5599523c110cd713f60a3bfa145dad807. * is_torch_available test to avoid failing imports * sorting parameterize parameters to solve ERROR gw0 gw1 * tests fix * tests fix * GPT2 with past fix * Fixed stateful class attribute change that was breaking things when converting multiple models sequentially * Removed onnx file * Refactor Marian export to account for base changes * Fix copies * Implemented suggestions * Extend support for causal LM * Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)" This reverts commit 0f4e39c5599523c110cd713f60a3bfa145dad807. * is_torch_available test to avoid failing imports * sorting parameterize parameters to solve ERROR gw0 gw1 * tests fix * tests fix * GPT2 with past fix * Fixed stateful class attribute change that was breaking things when converting multiple models sequentially * Removed onnx file * Implemented suggestions * Fixed __init__ to resolve conflict with master * Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)" This reverts commit 0f4e39c5599523c110cd713f60a3bfa145dad807. * is_torch_available test to avoid failing imports * sorting parameterize parameters to solve ERROR gw0 gw1 * tests fix * tests fix * GPT2 with past fix * Fixed stateful class attribute change that was breaking things when converting multiple models sequentially * Removed onnx file * Implemented suggestions * Fixed __init__ to resolve conflict with master * Remove commented import * Remove ONNX model * Remove redundant class method * Tidy up imports * Fix quality * Refactor dummy input function * Add copied from statements to Marian config functions * Remove false copied from comments * Fix copy from comment Co-authored-by: Massimiliano Bruni Co-authored-by: Michael Benayoun --- docs/source/serialization.mdx | 1 + src/transformers/models/marian/__init__.py | 4 +- .../models/marian/configuration_marian.py | 229 ++++++++++++++++++ .../models/marian/tokenization_marian.py | 2 +- src/transformers/onnx/features.py | 10 + tests/test_onnx_v2.py | 1 + 6 files changed, 244 insertions(+), 3 deletions(-) diff --git a/docs/source/serialization.mdx b/docs/source/serialization.mdx index a78bda68c2..66d0933b04 100644 --- a/docs/source/serialization.mdx +++ b/docs/source/serialization.mdx @@ -42,6 +42,7 @@ Ready-made configurations include the following models: - GPT Neo - LayoutLM - Longformer +- Marian - mBART - OpenAI GPT-2 - RoBERTa diff --git a/src/transformers/models/marian/__init__.py b/src/transformers/models/marian/__init__.py index 7e6ce7625e..3348965a80 100644 --- a/src/transformers/models/marian/__init__.py +++ b/src/transformers/models/marian/__init__.py @@ -28,7 +28,7 @@ from ...file_utils import ( _import_structure = { - "configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig"], + "configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig", "MarianOnnxConfig"], } if is_sentencepiece_available(): @@ -49,7 +49,7 @@ if is_tf_available(): if is_flax_available(): _import_structure["modeling_flax_marian"] = ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel"] if TYPE_CHECKING: - from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig + from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig, MarianOnnxConfig if is_sentencepiece_available(): from .tokenization_marian import MarianTokenizer diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py index a0be3c7723..6ee95889b6 100644 --- a/src/transformers/models/marian/configuration_marian.py +++ b/src/transformers/models/marian/configuration_marian.py @@ -13,8 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Marian model configuration """ +from collections import OrderedDict +from typing import Any, Mapping, Optional +from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig +from ...file_utils import TensorType, is_torch_available +from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast +from ...onnx.utils import compute_effective_axis_dimension from ...utils import logging @@ -160,3 +166,226 @@ class MarianConfig(PretrainedConfig): forced_eos_token_id=forced_eos_token_id, **kwargs, ) + + +class MarianOnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.inputs + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + + if self.use_past: + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + elif self.task == "causal-lm": + # TODO: figure this case out. + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + else: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), + ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), + ] + ) + + return common_inputs + + @property + # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.outputs + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_outputs = super().outputs + else: + common_outputs = super(OnnxConfigWithPast, self).outputs + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + return common_outputs + + def _generate_dummy_inputs_for_default_and_seq2seq_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + encoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + # Generate decoder inputs + decoder_seq_length = seq_length if not self.use_past else 1 + decoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder( + tokenizer, batch_size, decoder_seq_length, is_pair, framework + ) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + common_inputs = dict(**encoder_inputs, **decoder_inputs) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, encoder_seq_length = common_inputs["input_ids"].shape + decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] + num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads + encoder_shape = ( + batch, + num_encoder_attention_heads, + encoder_seq_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + decoder_past_length = decoder_seq_length + 3 + decoder_shape = ( + batch, + num_decoder_attention_heads, + decoder_past_length, + self._config.hidden_size // num_decoder_attention_heads, + ) + + common_inputs["decoder_attention_mask"] = torch.cat( + [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 + ) + + common_inputs["past_key_values"] = [] + # If the number of encoder and decoder layers are present in the model configuration, both are considered + num_encoder_layers, num_decoder_layers = self.num_layers + min_num_layers = min(num_encoder_layers, num_decoder_layers) + max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers + remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" + + for _ in range(min_num_layers): + common_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + # TODO: test this. + shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape + for _ in range(min_num_layers, max_num_layers): + common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) + return common_inputs + + def _generate_dummy_inputs_for_causal_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = self._generate_dummy_inputs_for_encoder_and_decoder( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + num_encoder_layers, _ = self.num_layers + num_encoder_attention_heads, _ = self.num_attention_heads + past_shape = ( + batch, + num_encoder_attention_heads, + past_key_values_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + + common_inputs["attention_mask"] = torch.cat( + [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 + ) + common_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) + ] + return common_inputs + + # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering + # We renamed this function because Marian models do not have a sequence classification or question answering head + def _generate_dummy_inputs_for_encoder_and_decoder( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + # Copied from OnnxConfig.generate_dummy_inputs + # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. + # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0 + ) + + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add + ) + + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + return common_inputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + else: + common_inputs = self._generate_dummy_inputs_for_causal_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + return common_inputs + + # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._flatten_past_key_values_ + def _flatten_past_key_values_(self, flattened_output, name, idx, t): + if self.task in ["default", "seq2seq-lm"]: + flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) + else: + flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( + flattened_output, name, idx, t + ) diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py index 5022569b8d..3ec362565c 100644 --- a/src/transformers/models/marian/tokenization_marian.py +++ b/src/transformers/models/marian/tokenization_marian.py @@ -310,7 +310,7 @@ class MarianTokenizer(PreTrainedTokenizer): self.current_spm = self.spm_source self._setup_normalizer() - def num_special_tokens_to_add(self, **unused): + def num_special_tokens_to_add(self, *args, **kwargs): """Just EOS""" return 1 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index b1d0f10453..2f12a574bf 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -11,6 +11,7 @@ from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig from ..models.layoutlm import LayoutLMOnnxConfig from ..models.longformer import LongformerOnnxConfig +from ..models.marian import MarianOnnxConfig from ..models.mbart import MBartOnnxConfig from ..models.roberta import RobertaOnnxConfig from ..models.t5 import T5OnnxConfig @@ -152,6 +153,15 @@ class FeaturesManager: "question-answering", onnx_config_cls=LongformerOnnxConfig, ), + "marian": supported_features_mapping( + "default", + "default-with-past", + "seq2seq-lm", + "seq2seq-lm-with-past", + "causal-lm", + "causal-lm-with-past", + onnx_config_cls=MarianOnnxConfig, + ), "roberta": supported_features_mapping( "default", "masked-lm", diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index 27cc10e9fa..fdd7694c41 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -188,6 +188,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = { ("bart", "facebook/bart-base"), ("mbart", "sshleifer/tiny-mbart"), ("t5", "t5-small"), + ("marian", "Helsinki-NLP/opus-mt-en-de"), }