Turn on eval mode when exporting to ONNX (#12758)
* Set model in eval mode when exporting to ONNX. * Disable t5 for now. * Disable T5 with past too. * Style.
This commit is contained in:
@@ -87,6 +87,7 @@ def export(
|
||||
logger.info(f"Using framework PyTorch: {torch.__version__}")
|
||||
torch.set_grad_enabled(False)
|
||||
model.config.return_dict = True
|
||||
model.eval()
|
||||
|
||||
# Check if we need to override certain configuration item
|
||||
if config.values_override is not None:
|
||||
|
||||
@@ -3,14 +3,13 @@ from tempfile import NamedTemporaryFile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import ( # LongformerConfig,
|
||||
from transformers import ( # LongformerConfig,; T5Config,
|
||||
AlbertConfig,
|
||||
AutoTokenizer,
|
||||
BartConfig,
|
||||
DistilBertConfig,
|
||||
GPT2Config,
|
||||
RobertaConfig,
|
||||
T5Config,
|
||||
XLMRobertaConfig,
|
||||
is_torch_available,
|
||||
)
|
||||
@@ -22,7 +21,8 @@ from transformers.models.distilbert import DistilBertOnnxConfig
|
||||
# from transformers.models.longformer import LongformerOnnxConfig
|
||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||
from transformers.models.roberta import RobertaOnnxConfig
|
||||
from transformers.models.t5 import T5OnnxConfig
|
||||
|
||||
# from transformers.models.t5 import T5OnnxConfig
|
||||
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
||||
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
|
||||
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
||||
@@ -122,7 +122,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
||||
Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX)
|
||||
"""
|
||||
|
||||
SUPPORTED_WITH_PAST_CONFIGS = {("BART", BartConfig), ("GPT2", GPT2Config), ("T5", T5Config)}
|
||||
SUPPORTED_WITH_PAST_CONFIGS = {
|
||||
("BART", BartConfig),
|
||||
("GPT2", GPT2Config),
|
||||
# ("T5", T5Config)
|
||||
}
|
||||
|
||||
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
|
||||
def test_use_past(self):
|
||||
@@ -165,14 +169,13 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (
|
||||
from transformers import ( # T5Model,
|
||||
AlbertModel,
|
||||
BartModel,
|
||||
BertModel,
|
||||
DistilBertModel,
|
||||
GPT2Model,
|
||||
RobertaModel,
|
||||
T5Model,
|
||||
XLMRobertaModel,
|
||||
)
|
||||
|
||||
@@ -185,7 +188,7 @@ if is_torch_available():
|
||||
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
||||
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
||||
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
||||
("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
|
||||
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||
|
||||
Reference in New Issue
Block a user