add GPT-J ONNX config to Transformers (#16274)
* add GPT-J ONNX config to Transformers * remove token-classification features mapping Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * add question-answering features mapping Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * add GPT2 config init to GPT2 config + copie shebang for fix-copies Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -54,6 +54,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- ELECTRA
|
- ELECTRA
|
||||||
- FlauBERT
|
- FlauBERT
|
||||||
- GPT Neo
|
- GPT Neo
|
||||||
|
- GPT-J
|
||||||
- I-BERT
|
- I-BERT
|
||||||
- LayoutLM
|
- LayoutLM
|
||||||
- M2M100
|
- M2M100
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
|
"configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -43,7 +43,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
|
from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig, GPTJOnnxConfig
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_gptj import (
|
from .modeling_gptj import (
|
||||||
|
|||||||
@@ -13,8 +13,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" GPT-J model configuration"""
|
""" GPT-J model configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, List, Mapping, Optional
|
||||||
|
|
||||||
|
from ... import PreTrainedTokenizer, TensorType, is_torch_available
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfigWithPast, PatchingSpec
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -135,3 +139,84 @@ class GPTJConfig(PretrainedConfig):
|
|||||||
super().__init__(
|
super().__init__(
|
||||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
|
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
|
||||||
|
class GPTJOnnxConfig(OnnxConfigWithPast):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
task: str = "default",
|
||||||
|
patching_specs: List[PatchingSpec] = None,
|
||||||
|
use_past: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
|
||||||
|
if not getattr(self._config, "pad_token_id", None):
|
||||||
|
# TODO: how to do that better?
|
||||||
|
self._config.pad_token_id = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||||
|
if self.use_past:
|
||||||
|
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||||
|
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
|
||||||
|
else:
|
||||||
|
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||||
|
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_layers(self) -> int:
|
||||||
|
return self._config.n_layer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_attention_heads(self) -> int:
|
||||||
|
return self._config.n_head
|
||||||
|
|
||||||
|
def generate_dummy_inputs(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||||
|
tokenizer, batch_size, seq_length, is_pair, framework
|
||||||
|
)
|
||||||
|
|
||||||
|
# We need to order the input in the way they appears in the forward()
|
||||||
|
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
|
||||||
|
|
||||||
|
# Need to add the past_keys
|
||||||
|
if self.use_past:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||||
|
else:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
batch, seqlen = common_inputs["input_ids"].shape
|
||||||
|
# Not using the same length for past_key_values
|
||||||
|
past_key_values_length = seqlen + 2
|
||||||
|
past_shape = (
|
||||||
|
batch,
|
||||||
|
self.num_attention_heads,
|
||||||
|
past_key_values_length,
|
||||||
|
self._config.hidden_size // self.num_attention_heads,
|
||||||
|
)
|
||||||
|
ordered_inputs["past_key_values"] = [
|
||||||
|
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
|
||||||
|
if self.use_past:
|
||||||
|
ordered_inputs["attention_mask"] = torch.cat(
|
||||||
|
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
return ordered_inputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_onnx_opset(self) -> int:
|
||||||
|
return 13
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from ..models.electra import ElectraOnnxConfig
|
|||||||
from ..models.flaubert import FlaubertOnnxConfig
|
from ..models.flaubert import FlaubertOnnxConfig
|
||||||
from ..models.gpt2 import GPT2OnnxConfig
|
from ..models.gpt2 import GPT2OnnxConfig
|
||||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||||
|
from ..models.gptj import GPTJOnnxConfig
|
||||||
from ..models.ibert import IBertOnnxConfig
|
from ..models.ibert import IBertOnnxConfig
|
||||||
from ..models.layoutlm import LayoutLMOnnxConfig
|
from ..models.layoutlm import LayoutLMOnnxConfig
|
||||||
from ..models.m2m_100 import M2M100OnnxConfig
|
from ..models.m2m_100 import M2M100OnnxConfig
|
||||||
@@ -233,6 +234,15 @@ class FeaturesManager:
|
|||||||
"token-classification",
|
"token-classification",
|
||||||
onnx_config_cls=GPT2OnnxConfig,
|
onnx_config_cls=GPT2OnnxConfig,
|
||||||
),
|
),
|
||||||
|
"gpt-j": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"default-with-past",
|
||||||
|
"causal-lm",
|
||||||
|
"causal-lm-with-past",
|
||||||
|
"question-answering",
|
||||||
|
"sequence-classification",
|
||||||
|
onnx_config_cls=GPTJOnnxConfig,
|
||||||
|
),
|
||||||
"gpt-neo": supported_features_mapping(
|
"gpt-neo": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"default-with-past",
|
"default-with-past",
|
||||||
|
|||||||
Reference in New Issue
Block a user