@@ -408,12 +408,12 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
|
onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
|
||||||
),
|
),
|
||||||
"mobilenet_v1": supported_features_mapping(
|
"mobilenet-v1": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"image-classification",
|
"image-classification",
|
||||||
onnx_config_cls="models.mobilenet_v1.MobileNetV1OnnxConfig",
|
onnx_config_cls="models.mobilenet_v1.MobileNetV1OnnxConfig",
|
||||||
),
|
),
|
||||||
"mobilenet_v2": supported_features_mapping(
|
"mobilenet-v2": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"image-classification",
|
"image-classification",
|
||||||
onnx_config_cls="models.mobilenet_v2.MobileNetV2OnnxConfig",
|
onnx_config_cls="models.mobilenet_v2.MobileNetV2OnnxConfig",
|
||||||
|
|||||||
@@ -272,7 +272,12 @@ def _get_models_to_test(export_models_list):
|
|||||||
feature: FeaturesManager.get_config(name, feature) for _ in features for feature in _
|
feature: FeaturesManager.get_config(name, feature) for _ in features for feature in _
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(name)
|
# pre-process the model names
|
||||||
|
model_type = name.replace("_", "-")
|
||||||
|
model_name = getattr(model, "name", "")
|
||||||
|
feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(
|
||||||
|
model_type, model_name=model_name
|
||||||
|
)
|
||||||
|
|
||||||
for feature, onnx_config_class_constructor in feature_config_mapping.items():
|
for feature, onnx_config_class_constructor in feature_config_mapping.items():
|
||||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||||
|
|||||||
Reference in New Issue
Block a user