This reverts commit 0c70f145d1.
This commit is contained in:
@@ -3,8 +3,33 @@ from tempfile import NamedTemporaryFile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers import ( # LongformerConfig,; T5Config,
|
||||
AlbertConfig,
|
||||
AutoTokenizer,
|
||||
BartConfig,
|
||||
DistilBertConfig,
|
||||
GPT2Config,
|
||||
GPTNeoConfig,
|
||||
LayoutLMConfig,
|
||||
MBartConfig,
|
||||
RobertaConfig,
|
||||
XLMRobertaConfig,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.models.albert import AlbertOnnxConfig
|
||||
from transformers.models.bart import BartOnnxConfig
|
||||
from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
|
||||
from transformers.models.distilbert import DistilBertOnnxConfig
|
||||
|
||||
# from transformers.models.longformer import LongformerOnnxConfig
|
||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||
from transformers.models.gpt_neo import GPTNeoOnnxConfig
|
||||
from transformers.models.layoutlm import LayoutLMOnnxConfig
|
||||
from transformers.models.mbart import MBartOnnxConfig
|
||||
from transformers.models.roberta import RobertaOnnxConfig
|
||||
|
||||
# from transformers.models.t5 import T5OnnxConfig
|
||||
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
||||
from transformers.onnx import (
|
||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||
OnnxConfig,
|
||||
@@ -12,8 +37,7 @@ from transformers.onnx import (
|
||||
export,
|
||||
validate_model_outputs,
|
||||
)
|
||||
from transformers.onnx.config import OnnxConfigWithPast
|
||||
from transformers.onnx.features import FeaturesManager
|
||||
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||
from transformers.testing_utils import require_onnx, require_torch, slow
|
||||
|
||||
@@ -115,12 +139,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
||||
Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX)
|
||||
"""
|
||||
|
||||
SUPPORTED_WITH_PAST_CONFIGS = {}
|
||||
# SUPPORTED_WITH_PAST_CONFIGS = {
|
||||
# ("BART", BartConfig),
|
||||
# ("GPT2", GPT2Config),
|
||||
# # ("T5", T5Config)
|
||||
# }
|
||||
SUPPORTED_WITH_PAST_CONFIGS = {
|
||||
("BART", BartConfig),
|
||||
("GPT2", GPT2Config),
|
||||
# ("T5", T5Config)
|
||||
}
|
||||
|
||||
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
|
||||
def test_use_past(self):
|
||||
@@ -164,37 +187,40 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
||||
)
|
||||
|
||||
|
||||
PYTORCH_EXPORT_MODELS = {
|
||||
("albert", "hf-internal-testing/tiny-albert"),
|
||||
("bert", "bert-base-cased"),
|
||||
("camembert", "camembert-base"),
|
||||
("distilbert", "distilbert-base-cased"),
|
||||
# ("longFormer", "longformer-base-4096"),
|
||||
("roberta", "roberta-base"),
|
||||
("xlm-roberta", "xlm-roberta-base"),
|
||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||
}
|
||||
if is_torch_available():
|
||||
from transformers import ( # T5Model,
|
||||
AlbertModel,
|
||||
BartModel,
|
||||
BertModel,
|
||||
DistilBertModel,
|
||||
GPT2Model,
|
||||
GPTNeoModel,
|
||||
LayoutLMModel,
|
||||
MBartModel,
|
||||
RobertaModel,
|
||||
XLMRobertaModel,
|
||||
)
|
||||
|
||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||
("gpt2", "gpt2"),
|
||||
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
||||
}
|
||||
PYTORCH_EXPORT_DEFAULT_MODELS = {
|
||||
("ALBERT", "hf-internal-testing/tiny-albert", AlbertModel, AlbertConfig, AlbertOnnxConfig),
|
||||
("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
|
||||
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
|
||||
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
|
||||
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
|
||||
("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
|
||||
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
||||
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
||||
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
||||
("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
|
||||
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
|
||||
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
||||
("bart", "facebook/bart-base"),
|
||||
("mbart", "sshleifer/tiny-mbart"),
|
||||
("t5", "t5-small"),
|
||||
}
|
||||
|
||||
|
||||
def _get_models_to_test(export_models_list):
|
||||
models_to_test = []
|
||||
for (name, model) in export_models_list:
|
||||
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
|
||||
name
|
||||
).items():
|
||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||
return models_to_test
|
||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||
# ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
|
||||
# ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
|
||||
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
|
||||
}
|
||||
|
||||
|
||||
class OnnxExportTestCaseV2(TestCase):
|
||||
@@ -202,52 +228,52 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
Integration tests ensuring supported models are correctly exported
|
||||
"""
|
||||
|
||||
def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pytorch_export_default(self):
|
||||
from transformers.onnx import export
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
|
||||
with self.subTest(name):
|
||||
self.assertTrue(hasattr(onnx_config_class, "from_model_config"))
|
||||
|
||||
# Useful for causal lm models that do not use pad tokens.
|
||||
if not getattr(config, "pad_token_id", None):
|
||||
config.pad_token_id = tokenizer.eos_token_id
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
model = model_class(config_class.from_pretrained(model))
|
||||
onnx_config = onnx_config_class.from_model_config(model.config)
|
||||
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature)
|
||||
model = model_class.from_config(config)
|
||||
onnx_config = onnx_config_class_constructor(model.config)
|
||||
with NamedTemporaryFile("w") as output:
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name)
|
||||
)
|
||||
|
||||
with NamedTemporaryFile("w") as output:
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
|
||||
)
|
||||
try:
|
||||
validate_model_outputs(
|
||||
onnx_config,
|
||||
tokenizer,
|
||||
model,
|
||||
Path(output.name),
|
||||
onnx_outputs,
|
||||
onnx_config.atol_for_validation,
|
||||
try:
|
||||
validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
|
||||
except ValueError as ve:
|
||||
self.fail(f"{name} -> {ve}")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pytorch_export_with_past(self):
|
||||
from transformers.onnx import export
|
||||
|
||||
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS:
|
||||
with self.subTest(name):
|
||||
self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
model = model_class(config_class())
|
||||
onnx_config = onnx_config_class.with_past(model.config)
|
||||
|
||||
self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.")
|
||||
self.assertTrue(
|
||||
onnx_config.use_past, "OnnxConfigWithPast.use_past should be if called with with_past()"
|
||||
)
|
||||
except ValueError as ve:
|
||||
self.fail(f"{name}, {feature} -> {ve}")
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
with NamedTemporaryFile("w") as output:
|
||||
output = Path(output.name)
|
||||
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pytorch_export_seq2seq_with_past(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
try:
|
||||
validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5)
|
||||
except ValueError as ve:
|
||||
self.fail(f"{name} -> {ve}")
|
||||
|
||||
Reference in New Issue
Block a user