GPT-Neo ONNX export (#12911)
GPT-Neo ONNX export and task/feature refactoring. Authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
@@ -9,6 +9,7 @@ from transformers import ( # LongformerConfig,; T5Config,
|
||||
BartConfig,
|
||||
DistilBertConfig,
|
||||
GPT2Config,
|
||||
+ GPTNeoConfig,
|
||||
RobertaConfig,
|
||||
XLMRobertaConfig,
|
||||
is_torch_available,
|
||||
@@ -20,6 +21,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
|
||||
|
||||
# from transformers.models.longformer import LongformerOnnxConfig
|
||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||
+ from transformers.models.gpt_neo import GPTNeoOnnxConfig
|
||||
from transformers.models.roberta import RobertaOnnxConfig
|
||||
|
||||
# from transformers.models.t5 import T5OnnxConfig
|
||||
@@ -151,7 +153,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
||||
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
|
||||
with self.subTest(name):
|
||||
self.assertFalse(
|
||||
- OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
|
||||
+ OnnxConfigWithPast.from_model_config(config()).use_past,
|
||||
+ "OnnxConfigWithPast.from_model_config() should not use_past",
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
@@ -167,7 +170,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
||||
with self.subTest(name):
|
||||
|
||||
# without past
|
||||
- onnx_config_default = OnnxConfigWithPast.default(config())
|
||||
+ onnx_config_default = OnnxConfigWithPast.from_model_config(config())
|
||||
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
|
||||
self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
|
||||
self.assertFalse(
|
||||
@@ -190,6 +193,7 @@ if is_torch_available():
|
||||
BertModel,
|
||||
DistilBertModel,
|
||||
GPT2Model,
|
||||
+ GPTNeoModel,
|
||||
RobertaModel,
|
||||
XLMRobertaModel,
|
||||
)
|
||||
@@ -200,6 +204,7 @@ if is_torch_available():
|
||||
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
|
||||
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
|
||||
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
|
||||
+ ("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
|
||||
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
||||
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
||||
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
||||
|
||||
Reference in New Issue
Block a user