From 52dd2b61bff8af5b6409fdd5ec92a9b3114f3636 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 22 Dec 2022 18:52:54 +0100 Subject: [PATCH] [`MobileNet-v2`] Fix ONNX typo (#20860) * fix typo `onnx` * fix test --- src/transformers/onnx/features.py | 4 ++-- tests/onnx/test_onnx_v2.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 0fb750ebc5..b371844c9d 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -408,12 +408,12 @@ class FeaturesManager: "question-answering", onnx_config_cls="models.mobilebert.MobileBertOnnxConfig", ), - "mobilenet_v1": supported_features_mapping( + "mobilenet-v1": supported_features_mapping( "default", "image-classification", onnx_config_cls="models.mobilenet_v1.MobileNetV1OnnxConfig", ), - "mobilenet_v2": supported_features_mapping( + "mobilenet-v2": supported_features_mapping( "default", "image-classification", onnx_config_cls="models.mobilenet_v2.MobileNetV2OnnxConfig", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index fbc959284d..afe4c1c036 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -272,7 +272,12 @@ def _get_models_to_test(export_models_list): feature: FeaturesManager.get_config(name, feature) for _ in features for feature in _ } else: - feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(name) + # pre-process the model names + model_type = name.replace("_", "-") + model_name = getattr(model, "name", "") + feature_config_mapping = FeaturesManager.get_supported_features_for_model_type( + model_type, model_name=model_name + ) for feature, onnx_config_class_constructor in feature_config_mapping.items(): models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))