Use tiny models for ONNX tests - text modality (#20333)
* Use tiny ONNX models * Fix broken tests * Add tiny perceiver * Add tiny convbert
This commit is contained in:
@@ -179,48 +179,48 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
PYTORCH_EXPORT_MODELS = {
|
PYTORCH_EXPORT_MODELS = {
|
||||||
("albert", "hf-internal-testing/tiny-albert"),
|
("albert", "hf-internal-testing/tiny-random-AlbertModel"),
|
||||||
("bert", "bert-base-cased"),
|
("bert", "hf-internal-testing/tiny-random-BertModel"),
|
||||||
("big-bird", "google/bigbird-roberta-base"),
|
("beit", "microsoft/beit-base-patch16-224"),
|
||||||
("ibert", "kssteven/ibert-roberta-base"),
|
("big-bird", "hf-internal-testing/tiny-random-BigBirdModel"),
|
||||||
("camembert", "camembert-base"),
|
("camembert", "camembert-base"),
|
||||||
("clip", "openai/clip-vit-base-patch32"),
|
("clip", "hf-internal-testing/tiny-random-CLIPModel"),
|
||||||
("convbert", "YituTech/conv-bert-base"),
|
("convbert", "hf-internal-testing/tiny-random-ConvBertModel"),
|
||||||
("codegen", "Salesforce/codegen-350M-multi"),
|
("codegen", "hf-internal-testing/tiny-random-CodeGenModel"),
|
||||||
("deberta", "microsoft/deberta-base"),
|
("data2vec-text", "hf-internal-testing/tiny-random-Data2VecTextModel"),
|
||||||
("deberta-v2", "microsoft/deberta-v2-xlarge"),
|
("data2vec-vision", "facebook/data2vec-vision-base"),
|
||||||
|
("deberta", "hf-internal-testing/tiny-random-DebertaModel"),
|
||||||
|
("deberta-v2", "hf-internal-testing/tiny-random-DebertaV2Model"),
|
||||||
|
("deit", "facebook/deit-small-patch16-224"),
|
||||||
("convnext", "facebook/convnext-tiny-224"),
|
("convnext", "facebook/convnext-tiny-224"),
|
||||||
("detr", "facebook/detr-resnet-50"),
|
("detr", "facebook/detr-resnet-50"),
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "hf-internal-testing/tiny-random-DistilBertModel"),
|
||||||
("electra", "google/electra-base-generator"),
|
("electra", "hf-internal-testing/tiny-random-ElectraModel"),
|
||||||
|
("groupvit", "nvidia/groupvit-gcc-yfcc"),
|
||||||
|
("ibert", "kssteven/ibert-roberta-base"),
|
||||||
("imagegpt", "openai/imagegpt-small"),
|
("imagegpt", "openai/imagegpt-small"),
|
||||||
("resnet", "microsoft/resnet-50"),
|
("levit", "facebook/levit-128S"),
|
||||||
("roberta", "roberta-base"),
|
("layoutlm", "hf-internal-testing/tiny-random-LayoutLMModel"),
|
||||||
("roformer", "junnyu/roformer_chinese_base"),
|
("layoutlmv3", "microsoft/layoutlmv3-base"),
|
||||||
("squeezebert", "squeezebert/squeezebert-uncased"),
|
("longformer", "allenai/longformer-base-4096"),
|
||||||
("mobilebert", "google/mobilebert-uncased"),
|
("mobilebert", "hf-internal-testing/tiny-random-MobileBertModel"),
|
||||||
("mobilenet_v1", "google/mobilenet_v1_0.75_192"),
|
("mobilenet_v1", "google/mobilenet_v1_0.75_192"),
|
||||||
("mobilenet_v2", "google/mobilenet_v2_0.35_96"),
|
("mobilenet_v2", "google/mobilenet_v2_0.35_96"),
|
||||||
("mobilevit", "apple/mobilevit-small"),
|
("mobilevit", "apple/mobilevit-small"),
|
||||||
("xlm", "xlm-clm-ende-1024"),
|
|
||||||
("xlm-roberta", "xlm-roberta-base"),
|
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
|
||||||
("layoutlmv3", "microsoft/layoutlmv3-base"),
|
|
||||||
("groupvit", "nvidia/groupvit-gcc-yfcc"),
|
|
||||||
("levit", "facebook/levit-128S"),
|
|
||||||
("owlvit", "google/owlvit-base-patch32"),
|
("owlvit", "google/owlvit-base-patch32"),
|
||||||
("vit", "google/vit-base-patch16-224"),
|
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
|
||||||
("deit", "facebook/deit-small-patch16-224"),
|
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
|
||||||
("beit", "microsoft/beit-base-patch16-224"),
|
("resnet", "microsoft/resnet-50"),
|
||||||
("data2vec-text", "facebook/data2vec-text-base"),
|
("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
|
||||||
("data2vec-vision", "facebook/data2vec-vision-base"),
|
("roformer", "hf-internal-testing/tiny-random-RoFormerModel"),
|
||||||
("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
|
|
||||||
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
|
|
||||||
("longformer", "allenai/longformer-base-4096"),
|
|
||||||
("yolos", "hustvl/yolos-tiny"),
|
|
||||||
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
|
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
|
||||||
|
("squeezebert", "hf-internal-testing/tiny-random-SqueezeBertModel"),
|
||||||
("swin", "microsoft/swin-tiny-patch4-window7-224"),
|
("swin", "microsoft/swin-tiny-patch4-window7-224"),
|
||||||
|
("vit", "google/vit-base-patch16-224"),
|
||||||
|
("yolos", "hustvl/yolos-tiny"),
|
||||||
("whisper", "openai/whisper-tiny.en"),
|
("whisper", "openai/whisper-tiny.en"),
|
||||||
|
("xlm", "hf-internal-testing/tiny-random-XLMModel"),
|
||||||
|
("xlm-roberta", "hf-internal-testing/tiny-random-XLMRobertaXLModel"),
|
||||||
}
|
}
|
||||||
|
|
||||||
PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
|
PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
|
||||||
@@ -228,34 +228,31 @@ PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||||
("bloom", "bigscience/bloom-560m"),
|
("bloom", "hf-internal-testing/tiny-random-BloomModel"),
|
||||||
("gpt2", "gpt2"),
|
("gpt2", "hf-internal-testing/tiny-random-GPT2Model"),
|
||||||
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
("gpt-neo", "hf-internal-testing/tiny-random-GPTNeoModel"),
|
||||||
}
|
}
|
||||||
|
|
||||||
PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
||||||
("bart", "facebook/bart-base"),
|
("bart", "hf-internal-testing/tiny-random-BartModel"),
|
||||||
("mbart", "sshleifer/tiny-mbart"),
|
("bigbird-pegasus", "hf-internal-testing/tiny-random-BigBirdPegasusModel"),
|
||||||
("t5", "t5-small"),
|
|
||||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
|
||||||
("mt5", "google/mt5-base"),
|
|
||||||
("m2m-100", "facebook/m2m100_418M"),
|
|
||||||
("blenderbot-small", "facebook/blenderbot_small-90M"),
|
("blenderbot-small", "facebook/blenderbot_small-90M"),
|
||||||
("blenderbot", "facebook/blenderbot-400M-distill"),
|
("blenderbot", "hf-internal-testing/tiny-random-BlenderbotModel"),
|
||||||
("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"),
|
("longt5", "hf-internal-testing/tiny-random-LongT5Model"),
|
||||||
("longt5", "google/long-t5-local-base"),
|
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||||
# Disable for now as it causes fatal error `Floating point exception (core dumped)` and the subsequential tests are
|
("mbart", "sshleifer/tiny-mbart"),
|
||||||
# not run.
|
("mt5", "google/mt5-base"),
|
||||||
# ("longt5", "google/long-t5-tglobal-base"),
|
("m2m-100", "hf-internal-testing/tiny-random-M2M100Model"),
|
||||||
|
("t5", "hf-internal-testing/tiny-random-T5Model"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
||||||
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
||||||
("albert", "hf-internal-testing/tiny-albert"),
|
("albert", "hf-internal-testing/tiny-albert"),
|
||||||
("bert", "bert-base-cased"),
|
("bert", "hf-internal-testing/tiny-random-BertModel"),
|
||||||
("camembert", "camembert-base"),
|
("camembert", "camembert-base"),
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "hf-internal-testing/tiny-random-DistilBertModel"),
|
||||||
("roberta", "roberta-base"),
|
("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
||||||
|
|||||||
Reference in New Issue
Block a user