Layoutlm onnx support (Issue #13300) (#13562)

* Add support for exporting PyTorch LayoutLM to ONNX

* Added tests for converting LayoutLM to ONNX

* Add support for exporting PyTorch LayoutLM to ONNX

* Added tests for converting LayoutLM to ONNX

* cleanup

* Removed regression/ folder

* Add support for exporting PyTorch LayoutLM to ONNX

* Added tests for converting LayoutLM to ONNX

* cleanup

* Fixed import error

* Remove unnecessary import statements

* Changed max_2d_positions from class variable to instance variable of the config class

* Add support for exporting PyTorch LayoutLM to ONNX

* Added tests for converting LayoutLM to ONNX

* cleanup

* Add support for exporting PyTorch LayoutLM to ONNX

* cleanup

* Fixed import error

* Changed max_2d_positions from class variable to instance variable of the config class

* Use super class generate_dummy_inputs method

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Add support for Masked LM, sequence classification and token classification

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Removed uncessary import and method

* Fixed code styling

* Raise error if PyTorch is not installed

* Remove unnecessary import statement

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
This commit is contained in:
Nishant Prabhu
2021-09-22 01:09:37 +05:30
committed by GitHub
parent b7d264be0d
commit ddd4d02f30
4 changed files with 84 additions and 2 deletions

View File

@@ -10,6 +10,7 @@ from transformers import ( # LongformerConfig,; T5Config,
DistilBertConfig,
GPT2Config,
GPTNeoConfig,
LayoutLMConfig,
MBartConfig,
RobertaConfig,
XLMRobertaConfig,
@@ -23,6 +24,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.models.layoutlm import LayoutLMOnnxConfig
from transformers.models.mbart import MBartOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
@@ -193,6 +195,7 @@ if is_torch_available():
DistilBertModel,
GPT2Model,
GPTNeoModel,
LayoutLMModel,
MBartModel,
RobertaModel,
XLMRobertaModel,
@@ -208,6 +211,7 @@ if is_torch_available():
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
}