Instantiate model only once in pipeline (#10888)
* Instantiate model only once in pipeline * Remove documentation of deprecated method * Add FutureWarning * Update src/transformers/pipelines/base.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -47,6 +47,4 @@ Data format
|
|||||||
Utilities
|
Utilities
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autofunction:: transformers.pipelines.get_framework
|
|
||||||
|
|
||||||
.. autoclass:: transformers.pipelines.PipelineException
|
.. autoclass:: transformers.pipelines.PipelineException
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from .base import (
|
|||||||
PipelineDataFormat,
|
PipelineDataFormat,
|
||||||
PipelineException,
|
PipelineException,
|
||||||
get_default_model,
|
get_default_model,
|
||||||
get_framework,
|
infer_framework_from_model,
|
||||||
)
|
)
|
||||||
from .conversational import Conversation, ConversationalPipeline
|
from .conversational import Conversation, ConversationalPipeline
|
||||||
from .feature_extraction import FeatureExtractionPipeline
|
from .feature_extraction import FeatureExtractionPipeline
|
||||||
@@ -341,10 +341,6 @@ def pipeline(
|
|||||||
# At that point framework might still be undetermined
|
# At that point framework might still be undetermined
|
||||||
model = get_default_model(targeted_task, framework, task_options)
|
model = get_default_model(targeted_task, framework, task_options)
|
||||||
|
|
||||||
framework = framework or get_framework(model)
|
|
||||||
|
|
||||||
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
|
||||||
|
|
||||||
# Try to infer tokenizer from model or config name (if provided as str)
|
# Try to infer tokenizer from model or config name (if provided as str)
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
if isinstance(model, str):
|
if isinstance(model, str):
|
||||||
@@ -365,6 +361,12 @@ def pipeline(
|
|||||||
elif isinstance(config, str):
|
elif isinstance(config, str):
|
||||||
modelcard = config
|
modelcard = config
|
||||||
|
|
||||||
|
# Infer the framework form the model
|
||||||
|
if framework is None:
|
||||||
|
framework, model = infer_framework_from_model(model, targeted_task, revision=revision)
|
||||||
|
|
||||||
|
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
||||||
|
|
||||||
# Instantiate tokenizer if needed
|
# Instantiate tokenizer if needed
|
||||||
if isinstance(tokenizer, (str, tuple)):
|
if isinstance(tokenizer, (str, tuple)):
|
||||||
if isinstance(tokenizer, tuple):
|
if isinstance(tokenizer, tuple):
|
||||||
@@ -406,14 +408,13 @@ def pipeline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
|
model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
|
||||||
|
|
||||||
if task == "translation" and model.config.task_specific_params:
|
if task == "translation" and model.config.task_specific_params:
|
||||||
for key in model.config.task_specific_params:
|
for key in model.config.task_specific_params:
|
||||||
if key.startswith("translation"):
|
if key.startswith("translation"):
|
||||||
task = key
|
task = key
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{}"'.format(
|
f'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{task}"',
|
||||||
task
|
|
||||||
),
|
|
||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import sys
|
import sys
|
||||||
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from os.path import abspath, exists
|
from os.path import abspath, exists
|
||||||
@@ -46,6 +47,55 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_framework_from_model(model, model_classes: Optional[Dict[str, type]] = None, revision: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
|
||||||
|
|
||||||
|
If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
|
||||||
|
:obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
|
||||||
|
Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.
|
||||||
|
|
||||||
|
If both frameworks are installed and available for :obj:`model`, PyTorch is selected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
|
||||||
|
The model to infer the framework from. If :obj:`str`, a checkpoint name. The model to infer the framewrok
|
||||||
|
from.
|
||||||
|
model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
|
||||||
|
A mapping framework to class.
|
||||||
|
revision (:obj:`str`, `optional`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
|
identifier allowed by git.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Tuple`: A tuple framework, model.
|
||||||
|
"""
|
||||||
|
if not is_tf_available() and not is_torch_available():
|
||||||
|
raise RuntimeError(
|
||||||
|
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||||
|
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||||
|
"To install PyTorch, read the instructions at https://pytorch.org/."
|
||||||
|
)
|
||||||
|
if isinstance(model, str):
|
||||||
|
if is_torch_available() and not is_tf_available():
|
||||||
|
model_class = model_classes.get("pt", AutoModel)
|
||||||
|
model = model_class.from_pretrained(model, revision=revision)
|
||||||
|
elif is_tf_available() and not is_torch_available():
|
||||||
|
model_class = model_classes.get("tf", TFAutoModel)
|
||||||
|
model = model_class.from_pretrained(model, revision=revision)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
model_class = model_classes.get("pt", AutoModel)
|
||||||
|
model = model_class.from_pretrained(model, revision=revision)
|
||||||
|
except OSError:
|
||||||
|
model_class = model_classes.get("tf", TFAutoModel)
|
||||||
|
model = model_class.from_pretrained(model, revision=revision)
|
||||||
|
|
||||||
|
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||||
|
return framework, model
|
||||||
|
|
||||||
|
|
||||||
def get_framework(model, revision: Optional[str] = None):
|
def get_framework(model, revision: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Select framework (TensorFlow or PyTorch) to use.
|
Select framework (TensorFlow or PyTorch) to use.
|
||||||
@@ -55,6 +105,10 @@ def get_framework(model, revision: Optional[str] = None):
|
|||||||
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
|
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
|
||||||
the model name). If no specific model is provided, defaults to using PyTorch.
|
the model name). If no specific model is provided, defaults to using PyTorch.
|
||||||
"""
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"`get_framework` is deprecated and will be removed in v5, use `infer_framework_from_model` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||||
@@ -474,7 +528,7 @@ class Pipeline(_ScikitCompat):
|
|||||||
):
|
):
|
||||||
|
|
||||||
if framework is None:
|
if framework is None:
|
||||||
framework = get_framework(model)
|
framework = infer_framework_from_model(model)
|
||||||
|
|
||||||
self.task = task
|
self.task = task
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|||||||
Reference in New Issue
Block a user