diff --git a/.circleci/config.yml b/.circleci/config.yml index 3b895d0dd1..a8fcedad3f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -878,7 +878,7 @@ jobs: - v0.5-torch-{{ checksum "setup.py" }} - v0.5-{{ checksum "setup.py" }} - run: pip install --upgrade pip - - run: pip install .[torch,testing,sentencepiece,onnxruntime,vision,rjieba] + - run: pip install .[torch,tf,testing,sentencepiece,onnxruntime,vision,rjieba] - save_cache: key: v0.5-onnx-{{ checksum "setup.py" }} paths: @@ -912,7 +912,7 @@ jobs: - v0.5-torch-{{ checksum "setup.py" }} - v0.5-{{ checksum "setup.py" }} - run: pip install --upgrade pip - - run: pip install .[torch,testing,sentencepiece,onnxruntime,vision] + - run: pip install .[torch,tf,testing,sentencepiece,onnxruntime,vision] - save_cache: key: v0.5-onnx-{{ checksum "setup.py" }} paths: diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py index 6d665b3556..55ad5f54c9 100644 --- a/src/transformers/onnx/__main__.py +++ b/src/transformers/onnx/__main__.py @@ -38,7 +38,15 @@ def main(): "--atol", type=float, default=None, help="Absolute difference tolerence when validating the model." ) parser.add_argument( - "--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export." + "--framework", + type=str, + choices=["pt", "tf"], + default=None, + help=( + "The framework to use for the ONNX export." + " If not provided, will attempt to use the local checkpoint's original framework" + " or what is available in the environment." + ), ) parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.") parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.") diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 3f18c36983..eb57df1c96 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -1,10 +1,11 @@ +import os from functools import partial, reduce from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union import transformers from .. import PretrainedConfig, is_tf_available, is_torch_available -from ..utils import logging +from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging from .config import OnnxConfig @@ -566,9 +567,59 @@ class FeaturesManager: ) return task_to_automodel[task] + @staticmethod + def determine_framework(model: str, framework: str = None) -> str: + """ + Determines the framework to use for the export. + + The priority is in the following order: + 1. User input via `framework`. + 2. If local checkpoint is provided, use the same framework as the checkpoint. + 3. Available framework in environment, with priority given to PyTorch + + Args: + model (`str`): + The name of the model to export. + framework (`str`, *optional*, defaults to `None`): + The framework to use for the export. See above for priority if none provided. + + Returns: + The framework to use for the export. + + """ + if framework is not None: + return framework + + framework_map = {"pt": "PyTorch", "tf": "TensorFlow"} + exporter_map = {"pt": "torch", "tf": "tf2onnx"} + + if os.path.isdir(model): + if os.path.isfile(os.path.join(model, WEIGHTS_NAME)): + framework = "pt" + elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)): + framework = "tf" + else: + raise FileNotFoundError( + "Cannot determine framework from given checkpoint location." + f" There should be a {WEIGHTS_NAME} for PyTorch" + f" or {TF2_WEIGHTS_NAME} for TensorFlow." + ) + logger.info(f"Local {framework_map[framework]} model found.") + else: + if is_torch_available(): + framework = "pt" + elif is_tf_available(): + framework = "tf" + else: + raise EnvironmentError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.") + + logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.") + + return framework + @staticmethod def get_model_from_feature( - feature: str, model: str, framework: str = "pt", cache_dir: str = None + feature: str, model: str, framework: str = None, cache_dir: str = None ) -> Union["PreTrainedModel", "TFPreTrainedModel"]: """ Attempts to retrieve a model from a model's name and the feature to be enabled. @@ -578,20 +629,24 @@ class FeaturesManager: The feature required. model (`str`): The name of the model to export. - framework (`str`, *optional*, defaults to `"pt"`): - The framework to use for the export. + framework (`str`, *optional*, defaults to `None`): + The framework to use for the export. See `FeaturesManager.determine_framework` for the priority should + none be provided. Returns: The instance of the model. """ + framework = FeaturesManager.determine_framework(model, framework) model_class = FeaturesManager.get_model_class_for_feature(feature, framework) try: model = model_class.from_pretrained(model, cache_dir=cache_dir) except OSError: if framework == "pt": + logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.") model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir) else: + logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.") model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir) return model diff --git a/tests/onnx/test_features.py b/tests/onnx/test_features.py new file mode 100644 index 0000000000..4590ff0cc8 --- /dev/null +++ b/tests/onnx/test_features.py @@ -0,0 +1,111 @@ +from tempfile import TemporaryDirectory +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from transformers import AutoModel, TFAutoModel +from transformers.onnx import FeaturesManager +from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch + + +@require_torch +@require_tf +class DetermineFrameworkTest(TestCase): + """ + Test `FeaturesManager.determine_framework` + """ + + def setUp(self): + self.test_model = SMALL_MODEL_IDENTIFIER + self.framework_pt = "pt" + self.framework_tf = "tf" + + def _setup_pt_ckpt(self, save_dir): + model_pt = AutoModel.from_pretrained(self.test_model) + model_pt.save_pretrained(save_dir) + + def _setup_tf_ckpt(self, save_dir): + model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True) + model_tf.save_pretrained(save_dir) + + def test_framework_provided(self): + """ + Ensure the that the provided framework is returned. + """ + mock_framework = "mock_framework" + + # Framework provided - return whatever the user provides + result = FeaturesManager.determine_framework(self.test_model, mock_framework) + self.assertEqual(result, mock_framework) + + # Local checkpoint and framework provided - return provided framework + # PyTorch checkpoint + with TemporaryDirectory() as local_pt_ckpt: + self._setup_pt_ckpt(local_pt_ckpt) + result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework) + self.assertEqual(result, mock_framework) + + # TensorFlow checkpoint + with TemporaryDirectory() as local_tf_ckpt: + self._setup_tf_ckpt(local_tf_ckpt) + result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework) + self.assertEqual(result, mock_framework) + + def test_checkpoint_provided(self): + """ + Ensure that the determined framework is the one used for the local checkpoint. + + For the functionality to execute, local checkpoints are provided but framework is not. + """ + # PyTorch checkpoint + with TemporaryDirectory() as local_pt_ckpt: + self._setup_pt_ckpt(local_pt_ckpt) + result = FeaturesManager.determine_framework(local_pt_ckpt) + self.assertEqual(result, self.framework_pt) + + # TensorFlow checkpoint + with TemporaryDirectory() as local_tf_ckpt: + self._setup_tf_ckpt(local_tf_ckpt) + result = FeaturesManager.determine_framework(local_tf_ckpt) + self.assertEqual(result, self.framework_tf) + + # Invalid local checkpoint + with TemporaryDirectory() as local_invalid_ckpt: + with self.assertRaises(FileNotFoundError): + result = FeaturesManager.determine_framework(local_invalid_ckpt) + + def test_from_environment(self): + """ + Ensure that the determined framework is the one available in the environment. + + For the functionality to execute, framework and local checkpoints are not provided. + """ + # Framework not provided, hub model is used (no local checkpoint directory) + # TensorFlow not in environment -> use PyTorch + mock_tf_available = MagicMock(return_value=False) + with patch("transformers.onnx.features.is_tf_available", mock_tf_available): + result = FeaturesManager.determine_framework(self.test_model) + self.assertEqual(result, self.framework_pt) + + # PyTorch not in environment -> use TensorFlow + mock_torch_available = MagicMock(return_value=False) + with patch("transformers.onnx.features.is_torch_available", mock_torch_available): + result = FeaturesManager.determine_framework(self.test_model) + self.assertEqual(result, self.framework_tf) + + # Both in environment -> use PyTorch + mock_tf_available = MagicMock(return_value=True) + mock_torch_available = MagicMock(return_value=True) + with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch( + "transformers.onnx.features.is_torch_available", mock_torch_available + ): + result = FeaturesManager.determine_framework(self.test_model) + self.assertEqual(result, self.framework_pt) + + # Both not in environment -> raise error + mock_tf_available = MagicMock(return_value=False) + mock_torch_available = MagicMock(return_value=False) + with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch( + "transformers.onnx.features.is_torch_available", mock_torch_available + ): + with self.assertRaises(EnvironmentError): + result = FeaturesManager.determine_framework(self.test_model) diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index ba122f43f8..7a645bba12 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -434,7 +434,7 @@ def module_to_test_file(module_fname): return "tests/utils/test_cli.py" # Special case for onnx submodules elif len(splits) >= 2 and splits[-2] == "onnx": - return ["tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"] + return ["tests/onnx/test_features.py", "tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"] # Special case for utils (not the one in src/transformers, the ones at the root of the repo). elif len(splits) > 0 and splits[0] == "utils": default_test_file = f"tests/utils/test_utils_{module_name}"