Fix duplicate arguments passed to dummy inputs in ONNX export (#16045)

* Fix duplicate arguments passed to dummy inputs in ONNX export

* Fix M2M100 ONNX config

* Ensure we check PreTrained model only if torch is available

* Remove TensorFlow tests for models without PyTorch parity
This commit is contained in:
lewtun
2022-03-10 20:19:45 +01:00
committed by GitHub
parent ba21001f4c
commit 6b09328368
3 changed files with 42 additions and 31 deletions

View File

@@ -196,28 +196,19 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("m2m-100", "facebook/m2m100_418M"),
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"),
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"),
("roberta", "roberta-base"),
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
}
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
("gpt2", "gpt2"),
("gpt-neo", "EleutherAI/gpt-neo-125M"),
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {}
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("bart", "facebook/bart-base"),
("mbart", "sshleifer/tiny-mbart"),
("t5", "t5-small"),
("marian", "Helsinki-NLP/opus-mt-en-de"),
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {}
def _get_models_to_test(export_models_list):
@@ -312,13 +303,13 @@ class OnnxExportTestCaseV2(TestCase):
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS))
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS), skip_on_empty=True)
@slow
@require_tf
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS), skip_on_empty=True)
@slow
@require_tf
def test_tensorflow_export_seq2seq_with_past(