From 06a6fea7820dc3e89d09430a49bce1c72b173647 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 29 Mar 2021 10:39:14 -0400 Subject: [PATCH] Instantiate model only once in pipeline (#10888) * Instantiate model only once in pipeline * Remove documentation of deprecated method * Add FutureWarning * Update src/transformers/pipelines/base.py Co-authored-by: Lysandre Debut Co-authored-by: Lysandre Debut --- docs/source/internal/pipelines_utils.rst | 2 - src/transformers/pipelines/__init__.py | 33 +++++++------- src/transformers/pipelines/base.py | 56 +++++++++++++++++++++++- 3 files changed, 72 insertions(+), 19 deletions(-) diff --git a/docs/source/internal/pipelines_utils.rst b/docs/source/internal/pipelines_utils.rst index 5d93defafd..e2181a6550 100644 --- a/docs/source/internal/pipelines_utils.rst +++ b/docs/source/internal/pipelines_utils.rst @@ -47,6 +47,4 @@ Data format Utilities ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: transformers.pipelines.get_framework - .. autoclass:: transformers.pipelines.PipelineException diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 762994fa86..43b1549627 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -34,7 +34,7 @@ from .base import ( PipelineDataFormat, PipelineException, get_default_model, - get_framework, + infer_framework_from_model, ) from .conversational import Conversation, ConversationalPipeline from .feature_extraction import FeatureExtractionPipeline @@ -341,10 +341,6 @@ def pipeline( # At that point framework might still be undetermined model = get_default_model(targeted_task, framework, task_options) - framework = framework or get_framework(model) - - task_class, model_class = targeted_task["impl"], targeted_task[framework] - # Try to infer tokenizer from model or config name (if provided as str) if tokenizer is None: if isinstance(model, str): @@ -365,6 +361,12 @@ def pipeline( elif isinstance(config, str): modelcard = config + # Infer the framework form the model + if framework is None: + framework, model = infer_framework_from_model(model, targeted_task, revision=revision) + + task_class, model_class = targeted_task["impl"], targeted_task[framework] + # Instantiate tokenizer if needed if isinstance(tokenizer, (str, tuple)): if isinstance(tokenizer, tuple): @@ -406,16 +408,15 @@ def pipeline( ) model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs) - if task == "translation" and model.config.task_specific_params: - for key in model.config.task_specific_params: - if key.startswith("translation"): - task = key - warnings.warn( - '"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{}"'.format( - task - ), - UserWarning, - ) - break + + if task == "translation" and model.config.task_specific_params: + for key in model.config.task_specific_params: + if key.startswith("translation"): + task = key + warnings.warn( + f'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{task}"', + UserWarning, + ) + break return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 124f2e290e..01d3699c6f 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -17,6 +17,7 @@ import json import os import pickle import sys +import warnings from abc import ABC, abstractmethod from contextlib import contextmanager from os.path import abspath, exists @@ -46,6 +47,55 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) +def infer_framework_from_model(model, model_classes: Optional[Dict[str, type]] = None, revision: Optional[str] = None): + """ + Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model). + + If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise + :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`. + Since we don't want to instantiate the model twice, this model is returned for use by the pipeline. + + If both frameworks are installed and available for :obj:`model`, PyTorch is selected. + + Args: + model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`): + The model to infer the framework from. If :obj:`str`, a checkpoint name. The model to infer the framewrok + from. + model_classes (dictionary :obj:`str` to :obj:`type`, `optional`): + A mapping framework to class. + revision (:obj:`str`, `optional`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + + Returns: + :obj:`Tuple`: A tuple framework, model. + """ + if not is_tf_available() and not is_torch_available(): + raise RuntimeError( + "At least one of TensorFlow 2.0 or PyTorch should be installed. " + "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " + "To install PyTorch, read the instructions at https://pytorch.org/." + ) + if isinstance(model, str): + if is_torch_available() and not is_tf_available(): + model_class = model_classes.get("pt", AutoModel) + model = model_class.from_pretrained(model, revision=revision) + elif is_tf_available() and not is_torch_available(): + model_class = model_classes.get("tf", TFAutoModel) + model = model_class.from_pretrained(model, revision=revision) + else: + try: + model_class = model_classes.get("pt", AutoModel) + model = model_class.from_pretrained(model, revision=revision) + except OSError: + model_class = model_classes.get("tf", TFAutoModel) + model = model_class.from_pretrained(model, revision=revision) + + framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" + return framework, model + + def get_framework(model, revision: Optional[str] = None): """ Select framework (TensorFlow or PyTorch) to use. @@ -55,6 +105,10 @@ def get_framework(model, revision: Optional[str] = None): If both frameworks are installed, picks the one corresponding to the model passed (either a model class or the model name). If no specific model is provided, defaults to using PyTorch. """ + warnings.warn( + "`get_framework` is deprecated and will be removed in v5, use `infer_framework_from_model` instead.", + FutureWarning, + ) if not is_tf_available() and not is_torch_available(): raise RuntimeError( "At least one of TensorFlow 2.0 or PyTorch should be installed. " @@ -474,7 +528,7 @@ class Pipeline(_ScikitCompat): ): if framework is None: - framework = get_framework(model) + framework = infer_framework_from_model(model) self.task = task self.model = model