@@ -408,12 +408,12 @@ class FeaturesManager:
|
||||
"question-answering",
|
||||
onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
|
||||
),
|
||||
"mobilenet_v1": supported_features_mapping(
|
||||
"mobilenet-v1": supported_features_mapping(
|
||||
"default",
|
||||
"image-classification",
|
||||
onnx_config_cls="models.mobilenet_v1.MobileNetV1OnnxConfig",
|
||||
),
|
||||
"mobilenet_v2": supported_features_mapping(
|
||||
"mobilenet-v2": supported_features_mapping(
|
||||
"default",
|
||||
"image-classification",
|
||||
onnx_config_cls="models.mobilenet_v2.MobileNetV2OnnxConfig",
|
||||
|
||||
@@ -272,7 +272,12 @@ def _get_models_to_test(export_models_list):
|
||||
feature: FeaturesManager.get_config(name, feature) for _ in features for feature in _
|
||||
}
|
||||
else:
|
||||
feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(name)
|
||||
# pre-process the model names
|
||||
model_type = name.replace("_", "-")
|
||||
model_name = getattr(model, "name", "")
|
||||
feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(
|
||||
model_type, model_name=model_name
|
||||
)
|
||||
|
||||
for feature, onnx_config_class_constructor in feature_config_mapping.items():
|
||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||
|
||||
Reference in New Issue
Block a user