From 029b0d95ed7282697bd2b361e95e6f6144151ae2 Mon Sep 17 00:00:00 2001 From: Thomas Chaigneau <50595514+ChainYo@users.noreply.github.com> Date: Wed, 23 Mar 2022 21:36:11 +0100 Subject: [PATCH] add GPT-J ONNX config to Transformers (#16274) * add GPT-J ONNX config to Transformers * remove token-classification features mapping Co-authored-by: lewtun * add question-answering features mapping Co-authored-by: lewtun * add GPT2 config init to GPT2 config + copie shebang for fix-copies Co-authored-by: ChainYo Co-authored-by: lewtun --- docs/source/serialization.mdx | 1 + src/transformers/models/gptj/__init__.py | 4 +- .../models/gptj/configuration_gptj.py | 85 +++++++++++++++++++ src/transformers/onnx/features.py | 10 +++ 4 files changed, 98 insertions(+), 2 deletions(-) diff --git a/docs/source/serialization.mdx b/docs/source/serialization.mdx index 633d468301..0cb5f51c69 100644 --- a/docs/source/serialization.mdx +++ b/docs/source/serialization.mdx @@ -54,6 +54,7 @@ Ready-made configurations include the following architectures: - ELECTRA - FlauBERT - GPT Neo +- GPT-J - I-BERT - LayoutLM - M2M100 diff --git a/src/transformers/models/gptj/__init__.py b/src/transformers/models/gptj/__init__.py index cbebc85279..69ca43f276 100644 --- a/src/transformers/models/gptj/__init__.py +++ b/src/transformers/models/gptj/__init__.py @@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available _import_structure = { - "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"], + "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"], } if is_torch_available(): @@ -43,7 +43,7 @@ if is_flax_available(): if TYPE_CHECKING: - from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig + from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig, GPTJOnnxConfig if is_torch_available(): from .modeling_gptj import ( diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py index e30cf2479b..25a193cdd9 100644 --- a/src/transformers/models/gptj/configuration_gptj.py +++ b/src/transformers/models/gptj/configuration_gptj.py @@ -13,8 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """ GPT-J model configuration""" +from collections import OrderedDict +from typing import Any, List, Mapping, Optional +from ... import PreTrainedTokenizer, TensorType, is_torch_available from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfigWithPast, PatchingSpec from ...utils import logging @@ -135,3 +139,84 @@ class GPTJConfig(PretrainedConfig): super().__init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs ) + + +# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig +class GPTJOnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 4e5bd8e9d3..d8e487ff14 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -11,6 +11,7 @@ from ..models.electra import ElectraOnnxConfig from ..models.flaubert import FlaubertOnnxConfig from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig +from ..models.gptj import GPTJOnnxConfig from ..models.ibert import IBertOnnxConfig from ..models.layoutlm import LayoutLMOnnxConfig from ..models.m2m_100 import M2M100OnnxConfig @@ -233,6 +234,15 @@ class FeaturesManager: "token-classification", onnx_config_cls=GPT2OnnxConfig, ), + "gpt-j": supported_features_mapping( + "default", + "default-with-past", + "causal-lm", + "causal-lm-with-past", + "question-answering", + "sequence-classification", + onnx_config_cls=GPTJOnnxConfig, + ), "gpt-neo": supported_features_mapping( "default", "default-with-past",