Fix duplicate arguments passed to dummy inputs in ONNX export (#16045)
* Fix duplicate arguments passed to dummy inputs in ONNX export * Fix M2M100 ONNX config * Ensure we check PreTrained model only if torch is available * Remove TensorFlow tests for models without PyTorch parity
This commit is contained in:
@@ -196,28 +196,19 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
||||
("m2m-100", "facebook/m2m100_418M"),
|
||||
}
|
||||
|
||||
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
||||
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
||||
("albert", "hf-internal-testing/tiny-albert"),
|
||||
("bert", "bert-base-cased"),
|
||||
("ibert", "kssteven/ibert-roberta-base"),
|
||||
("camembert", "camembert-base"),
|
||||
("distilbert", "distilbert-base-cased"),
|
||||
("roberta", "roberta-base"),
|
||||
("xlm-roberta", "xlm-roberta-base"),
|
||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||
}
|
||||
|
||||
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
|
||||
("gpt2", "gpt2"),
|
||||
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
||||
}
|
||||
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
||||
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {}
|
||||
|
||||
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
||||
("bart", "facebook/bart-base"),
|
||||
("mbart", "sshleifer/tiny-mbart"),
|
||||
("t5", "t5-small"),
|
||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||
}
|
||||
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
||||
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {}
|
||||
|
||||
|
||||
def _get_models_to_test(export_models_list):
|
||||
@@ -312,13 +303,13 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS))
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS), skip_on_empty=True)
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS), skip_on_empty=True)
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export_seq2seq_with_past(
|
||||
|
||||
Reference in New Issue
Block a user