* Add support for exporting PyTorch LayoutLM to ONNX * Added tests for converting LayoutLM to ONNX * Add support for exporting PyTorch LayoutLM to ONNX * Added tests for converting LayoutLM to ONNX * cleanup * Removed regression/ folder * Add support for exporting PyTorch LayoutLM to ONNX * Added tests for converting LayoutLM to ONNX * cleanup * Fixed import error * Remove unnecessary import statements * Changed max_2d_positions from class variable to instance variable of the config class * Add support for exporting PyTorch LayoutLM to ONNX * Added tests for converting LayoutLM to ONNX * cleanup * Add support for exporting PyTorch LayoutLM to ONNX * cleanup * Fixed import error * Changed max_2d_positions from class variable to instance variable of the config class * Use super class generate_dummy_inputs method Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com> * Add support for Masked LM, sequence classification and token classification Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com> * Removed unnecessary import and method * Fixed code styling * Raise error if PyTorch is not installed * Remove unnecessary import statement Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
This commit is contained in:
@@ -24,7 +24,7 @@ from .tokenization_layoutlm import LayoutLMTokenizer
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig"],
|
"configuration_layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMOnnxConfig"],
|
||||||
"tokenization_layoutlm": ["LayoutLMTokenizer"],
|
"tokenization_layoutlm": ["LayoutLMTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +54,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
|
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMOnnxConfig
|
||||||
from .tokenization_layoutlm import LayoutLMTokenizer
|
from .tokenization_layoutlm import LayoutLMTokenizer
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
|
|||||||
@@ -13,8 +13,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" LayoutLM model configuration """
|
""" LayoutLM model configuration """
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, List, Mapping, Optional
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||||
|
|
||||||
|
from ... import is_torch_available
|
||||||
|
from ...onnx import OnnxConfig, PatchingSpec
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..bert.configuration_bert import BertConfig
|
from ..bert.configuration_bert import BertConfig
|
||||||
|
|
||||||
@@ -125,3 +130,68 @@ class LayoutLMConfig(BertConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.max_2d_position_embeddings = max_2d_position_embeddings
|
self.max_2d_position_embeddings = max_2d_position_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class LayoutLMOnnxConfig(OnnxConfig):
    """
    ONNX export configuration for LayoutLM.

    Declares the model inputs (`input_ids`, `bbox`, `attention_mask`, `token_type_ids`) together with their dynamic
    axes, and extends the dummy inputs produced by the parent class with a `bbox` tensor.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: Optional[List[PatchingSpec]] = None,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs)
        # Largest coordinate written into the dummy bounding boxes: one below `config.max_2d_position_embeddings`
        # (presumably the last valid 2D position index -- confirm against the embedding layer).
        self.max_2d_positions = config.max_2d_position_embeddings - 1

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Ordered mapping from each input name to its dynamic axes (axis index -> axis name)."""
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("bbox", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """
        Generate inputs to provide to the ONNX exporter for the specific framework

        Args:
            tokenizer: The tokenizer associated with this model configuration
            batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
            seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
            is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
            framework: The framework (optional) the tokenizer will generate tensor for

        Returns:
            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function

        Raises:
            NotImplementedError: If `framework` is not `TensorType.PYTORCH`.
            ValueError: If PyTorch is not installed.
        """

        # Validate up front so unsupported setups fail before any tokenization work is done.
        if not framework == TensorType.PYTORCH:
            raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")

        if not is_torch_available():
            raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
        import torch

        input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

        # `batch_size` / `seq_length` may be -1 (dynamic axis), and the parent class decides the sizes that are
        # actually used. Read them back from `input_ids` so that `bbox` always lines up with the token tensor.
        effective_batch_size, effective_seq_length = input_dict["input_ids"].shape

        # Generate a dummy bbox
        box = [48, 84, 73, 128]

        if effective_seq_length >= 2:
            # Keep the original layout: an all-zero box first and an all-max box last (presumably standing in for
            # the leading/trailing special tokens), with the dummy box for every position in between.
            boxes = [
                [0] * 4,
                *[box] * (effective_seq_length - 2),
                [self.max_2d_positions] * 4,
            ]
        else:
            boxes = [box] * effective_seq_length

        input_dict["bbox"] = torch.tensor(boxes, dtype=torch.long).tile(effective_batch_size, 1, 1)
        return input_dict
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from ..models.bert import BertOnnxConfig
|
|||||||
from ..models.distilbert import DistilBertOnnxConfig
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
from ..models.gpt2 import GPT2OnnxConfig
|
from ..models.gpt2 import GPT2OnnxConfig
|
||||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||||
|
from ..models.layoutlm import LayoutLMOnnxConfig
|
||||||
from ..models.longformer import LongformerOnnxConfig
|
from ..models.longformer import LongformerOnnxConfig
|
||||||
from ..models.mbart import MBartOnnxConfig
|
from ..models.mbart import MBartOnnxConfig
|
||||||
from ..models.roberta import RobertaOnnxConfig
|
from ..models.roberta import RobertaOnnxConfig
|
||||||
@@ -78,6 +79,13 @@ class FeaturesManager:
|
|||||||
"sequence-classification-with-past",
|
"sequence-classification-with-past",
|
||||||
onnx_config_cls=GPTNeoOnnxConfig,
|
onnx_config_cls=GPTNeoOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"layoutlm": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"token-classification",
|
||||||
|
onnx_config_cls=LayoutLMOnnxConfig,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
|
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from transformers import ( # LongformerConfig,; T5Config,
|
|||||||
DistilBertConfig,
|
DistilBertConfig,
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
GPTNeoConfig,
|
GPTNeoConfig,
|
||||||
|
LayoutLMConfig,
|
||||||
MBartConfig,
|
MBartConfig,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
XLMRobertaConfig,
|
XLMRobertaConfig,
|
||||||
@@ -23,6 +24,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
|
|||||||
# from transformers.models.longformer import LongformerOnnxConfig
|
# from transformers.models.longformer import LongformerOnnxConfig
|
||||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||||
from transformers.models.gpt_neo import GPTNeoOnnxConfig
|
from transformers.models.gpt_neo import GPTNeoOnnxConfig
|
||||||
|
from transformers.models.layoutlm import LayoutLMOnnxConfig
|
||||||
from transformers.models.mbart import MBartOnnxConfig
|
from transformers.models.mbart import MBartOnnxConfig
|
||||||
from transformers.models.roberta import RobertaOnnxConfig
|
from transformers.models.roberta import RobertaOnnxConfig
|
||||||
|
|
||||||
@@ -193,6 +195,7 @@ if is_torch_available():
|
|||||||
DistilBertModel,
|
DistilBertModel,
|
||||||
GPT2Model,
|
GPT2Model,
|
||||||
GPTNeoModel,
|
GPTNeoModel,
|
||||||
|
LayoutLMModel,
|
||||||
MBartModel,
|
MBartModel,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
XLMRobertaModel,
|
XLMRobertaModel,
|
||||||
@@ -208,6 +211,7 @@ if is_torch_available():
|
|||||||
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
||||||
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
||||||
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
||||||
|
("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
|
||||||
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
|
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
|
||||||
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
|
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user