Determine framework automatically before ONNX export (#18615)
* Automatic detection for framework to use when exporting to ONNX * Log message change * Incorporating PR comments, adding unit test * Adding tf for pip install for run_tests_onnxruntime CI * Restoring past changes to circleci yaml and test_onnx_v2.py, tests moved to tests/onnx/test_features.py * Fixup * Adding test to fetcher * Updating circleci config to log more * Changing test class name * Comment typo fix in tests/onnx/test_features.py Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Moving torch_str/tf_str to self.framework_pt/tf * Remove -rA flag in circleci config Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
111
tests/onnx/test_features.py
Normal file
111
tests/onnx/test_features.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from transformers import AutoModel, TFAutoModel
|
||||
from transformers.onnx import FeaturesManager
|
||||
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_tf
|
||||
class DetermineFrameworkTest(TestCase):
|
||||
"""
|
||||
Test `FeaturesManager.determine_framework`
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.test_model = SMALL_MODEL_IDENTIFIER
|
||||
self.framework_pt = "pt"
|
||||
self.framework_tf = "tf"
|
||||
|
||||
def _setup_pt_ckpt(self, save_dir):
|
||||
model_pt = AutoModel.from_pretrained(self.test_model)
|
||||
model_pt.save_pretrained(save_dir)
|
||||
|
||||
def _setup_tf_ckpt(self, save_dir):
|
||||
model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
|
||||
model_tf.save_pretrained(save_dir)
|
||||
|
||||
def test_framework_provided(self):
|
||||
"""
|
||||
Ensure the that the provided framework is returned.
|
||||
"""
|
||||
mock_framework = "mock_framework"
|
||||
|
||||
# Framework provided - return whatever the user provides
|
||||
result = FeaturesManager.determine_framework(self.test_model, mock_framework)
|
||||
self.assertEqual(result, mock_framework)
|
||||
|
||||
# Local checkpoint and framework provided - return provided framework
|
||||
# PyTorch checkpoint
|
||||
with TemporaryDirectory() as local_pt_ckpt:
|
||||
self._setup_pt_ckpt(local_pt_ckpt)
|
||||
result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
|
||||
self.assertEqual(result, mock_framework)
|
||||
|
||||
# TensorFlow checkpoint
|
||||
with TemporaryDirectory() as local_tf_ckpt:
|
||||
self._setup_tf_ckpt(local_tf_ckpt)
|
||||
result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
|
||||
self.assertEqual(result, mock_framework)
|
||||
|
||||
def test_checkpoint_provided(self):
|
||||
"""
|
||||
Ensure that the determined framework is the one used for the local checkpoint.
|
||||
|
||||
For the functionality to execute, local checkpoints are provided but framework is not.
|
||||
"""
|
||||
# PyTorch checkpoint
|
||||
with TemporaryDirectory() as local_pt_ckpt:
|
||||
self._setup_pt_ckpt(local_pt_ckpt)
|
||||
result = FeaturesManager.determine_framework(local_pt_ckpt)
|
||||
self.assertEqual(result, self.framework_pt)
|
||||
|
||||
# TensorFlow checkpoint
|
||||
with TemporaryDirectory() as local_tf_ckpt:
|
||||
self._setup_tf_ckpt(local_tf_ckpt)
|
||||
result = FeaturesManager.determine_framework(local_tf_ckpt)
|
||||
self.assertEqual(result, self.framework_tf)
|
||||
|
||||
# Invalid local checkpoint
|
||||
with TemporaryDirectory() as local_invalid_ckpt:
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
result = FeaturesManager.determine_framework(local_invalid_ckpt)
|
||||
|
||||
def test_from_environment(self):
|
||||
"""
|
||||
Ensure that the determined framework is the one available in the environment.
|
||||
|
||||
For the functionality to execute, framework and local checkpoints are not provided.
|
||||
"""
|
||||
# Framework not provided, hub model is used (no local checkpoint directory)
|
||||
# TensorFlow not in environment -> use PyTorch
|
||||
mock_tf_available = MagicMock(return_value=False)
|
||||
with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
|
||||
result = FeaturesManager.determine_framework(self.test_model)
|
||||
self.assertEqual(result, self.framework_pt)
|
||||
|
||||
# PyTorch not in environment -> use TensorFlow
|
||||
mock_torch_available = MagicMock(return_value=False)
|
||||
with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
|
||||
result = FeaturesManager.determine_framework(self.test_model)
|
||||
self.assertEqual(result, self.framework_tf)
|
||||
|
||||
# Both in environment -> use PyTorch
|
||||
mock_tf_available = MagicMock(return_value=True)
|
||||
mock_torch_available = MagicMock(return_value=True)
|
||||
with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
|
||||
"transformers.onnx.features.is_torch_available", mock_torch_available
|
||||
):
|
||||
result = FeaturesManager.determine_framework(self.test_model)
|
||||
self.assertEqual(result, self.framework_pt)
|
||||
|
||||
# Both not in environment -> raise error
|
||||
mock_tf_available = MagicMock(return_value=False)
|
||||
mock_torch_available = MagicMock(return_value=False)
|
||||
with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
|
||||
"transformers.onnx.features.is_torch_available", mock_torch_available
|
||||
):
|
||||
with self.assertRaises(EnvironmentError):
|
||||
result = FeaturesManager.determine_framework(self.test_model)
|
||||
Reference in New Issue
Block a user