From b424f0b4a301abcbf3c282114159371ee44c3e01 Mon Sep 17 00:00:00 2001 From: mrbean <43734688+sam-h-bean@users.noreply.github.com> Date: Tue, 28 Jun 2022 08:57:53 -0400 Subject: [PATCH] Mrbean/codegen onnx (#17903) --- src/transformers/onnx/features.py | 5 +++++ tests/onnx/test_onnx_v2.py | 1 + 2 files changed, 6 insertions(+) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 9a76cfc012..c37c12ca2a 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -188,6 +188,11 @@ class FeaturesManager: "question-answering", onnx_config_cls="models.camembert.CamembertOnnxConfig", ), + "codegen": supported_features_mapping( + "default", + "causal-lm", + onnx_config_cls="models.codegen.CodeGenOnnxConfig", + ), "convbert": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index f409b36f91..50601598f5 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = { ("ibert", "kssteven/ibert-roberta-base"), ("camembert", "camembert-base"), ("convbert", "YituTech/conv-bert-base"), + ("codegen", "Salesforce/codegen-350M-multi"), ("deberta", "microsoft/deberta-base"), ("deberta-v2", "microsoft/deberta-v2-xlarge"), ("convnext", "facebook/convnext-tiny-224"),