From 76454b08c8ec09b0debeb1c94a3855cde8167d84 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Thu, 18 Aug 2022 15:13:54 +0200 Subject: [PATCH] Rename second input dimension from "sequence" to "num_channels" for CV models (#17976) --- src/transformers/models/beit/configuration_beit.py | 2 +- src/transformers/models/convnext/configuration_convnext.py | 2 +- .../models/data2vec/configuration_data2vec_vision.py | 2 +- src/transformers/models/deit/configuration_deit.py | 2 +- src/transformers/models/detr/configuration_detr.py | 4 ++-- .../models/layoutlmv3/configuration_layoutlmv3.py | 2 +- src/transformers/models/mobilevit/configuration_mobilevit.py | 2 +- src/transformers/models/resnet/configuration_resnet.py | 2 +- src/transformers/models/vit/configuration_vit.py | 2 +- tests/onnx/test_onnx_v2.py | 1 + 10 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py index c745f3227d..092f33ad85 100644 --- a/src/transformers/models/beit/configuration_beit.py +++ b/src/transformers/models/beit/configuration_beit.py @@ -194,7 +194,7 @@ class BeitOnnxConfig(OnnxConfig): def inputs(self) -> Mapping[str, Mapping[int, str]]: return OrderedDict( [ - ("pixel_values", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels"}), ] ) diff --git a/src/transformers/models/convnext/configuration_convnext.py b/src/transformers/models/convnext/configuration_convnext.py index 9f77c00992..0b31da4370 100644 --- a/src/transformers/models/convnext/configuration_convnext.py +++ b/src/transformers/models/convnext/configuration_convnext.py @@ -117,7 +117,7 @@ class ConvNextOnnxConfig(OnnxConfig): def inputs(self) -> Mapping[str, Mapping[int, str]]: return OrderedDict( [ - ("pixel_values", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels"}), ] ) diff --git a/src/transformers/models/data2vec/configuration_data2vec_vision.py b/src/transformers/models/data2vec/configuration_data2vec_vision.py index a7dd85b817..d6fc787176 100644 --- a/src/transformers/models/data2vec/configuration_data2vec_vision.py +++ b/src/transformers/models/data2vec/configuration_data2vec_vision.py @@ -193,7 +193,7 @@ class Data2VecVisionOnnxConfig(OnnxConfig): def inputs(self) -> Mapping[str, Mapping[int, str]]: return OrderedDict( [ - ("pixel_values", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels"}), ] ) diff --git a/src/transformers/models/deit/configuration_deit.py b/src/transformers/models/deit/configuration_deit.py index df74664ace..1e9154eeca 100644 --- a/src/transformers/models/deit/configuration_deit.py +++ b/src/transformers/models/deit/configuration_deit.py @@ -137,7 +137,7 @@ class DeiTOnnxConfig(OnnxConfig): def inputs(self) -> Mapping[str, Mapping[int, str]]: return OrderedDict( [ - ("pixel_values", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels"}), ] ) diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py index fa8086efc4..e46a5d610e 100644 --- a/src/transformers/models/detr/configuration_detr.py +++ b/src/transformers/models/detr/configuration_detr.py @@ -220,8 +220,8 @@ class DetrOnnxConfig(OnnxConfig): def inputs(self) -> Mapping[str, Mapping[int, str]]: return OrderedDict( [ - ("pixel_values", {0: "batch", 1: "sequence"}), - ("pixel_mask", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels"}), + ("pixel_mask", {0: "batch"}), ] ) diff --git a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py index d9ddde6289..ddf86ceaa1 100644 --- a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py @@ -212,7 +212,7 @@ class LayoutLMv3OnnxConfig(OnnxConfig): ("input_ids", {0: "batch", 1: "sequence"}), ("bbox", {0: "batch", 1: "sequence"}), ("attention_mask", {0: "batch", 1: "sequence"}), - ("pixel_values", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels"}), ] ) diff --git a/src/transformers/models/mobilevit/configuration_mobilevit.py b/src/transformers/models/mobilevit/configuration_mobilevit.py index 87a8a009dd..e2b2c568f6 100644 --- a/src/transformers/models/mobilevit/configuration_mobilevit.py +++ b/src/transformers/models/mobilevit/configuration_mobilevit.py @@ -171,7 +171,7 @@ class MobileViTOnnxConfig(OnnxConfig): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: - return OrderedDict([("pixel_values", {0: "batch"})]) + return OrderedDict([("pixel_values", {0: "batch", 1: "num_channels"})]) @property def outputs(self) -> Mapping[str, Mapping[int, str]]: diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py index 9bfc694bb1..61a7fc86de 100644 --- a/src/transformers/models/resnet/configuration_resnet.py +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -105,7 +105,7 @@ class ResNetOnnxConfig(OnnxConfig): def inputs(self) -> Mapping[str, Mapping[int, str]]: return OrderedDict( [ - ("pixel_values", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels"}), ] ) diff --git a/src/transformers/models/vit/configuration_vit.py b/src/transformers/models/vit/configuration_vit.py index e84fc6c25f..a65790f301 100644 --- a/src/transformers/models/vit/configuration_vit.py +++ b/src/transformers/models/vit/configuration_vit.py @@ -135,7 +135,7 @@ class ViTOnnxConfig(OnnxConfig): def inputs(self) -> Mapping[str, Mapping[int, str]]: return OrderedDict( [ - ("pixel_values", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels"}), ] ) diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index b3c0ffb1f3..829c7ec0a4 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -199,6 +199,7 @@ PYTORCH_EXPORT_MODELS = { ("roformer", "junnyu/roformer_chinese_base"), ("squeezebert", "squeezebert/squeezebert-uncased"), ("mobilebert", "google/mobilebert-uncased"), + ("mobilevit", "apple/mobilevit-small"), ("xlm", "xlm-clm-ende-1024"), ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"),