Extend pipelines for automodel tupels (#12025)
* fix_torch_device_generate_test * remove @ * finish * refactor * add test * fix test * Attempt at simplification. * Small fix. * Fixing non existing AutoModel for TF. * Naming. * Remove extra condition. Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -37,7 +37,7 @@ from .base import (
|
||||
PipelineDataFormat,
|
||||
PipelineException,
|
||||
get_default_model,
|
||||
infer_framework_from_model,
|
||||
infer_framework_load_model,
|
||||
)
|
||||
from .conversational import Conversation, ConversationalPipeline
|
||||
from .feature_extraction import FeatureExtractionPipeline
|
||||
@@ -110,14 +110,14 @@ TASK_ALIASES = {
|
||||
SUPPORTED_TASKS = {
|
||||
"feature-extraction": {
|
||||
"impl": FeatureExtractionPipeline,
|
||||
"tf": TFAutoModel if is_tf_available() else None,
|
||||
"pt": AutoModel if is_torch_available() else None,
|
||||
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||
"pt": (AutoModel,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
|
||||
},
|
||||
"text-classification": {
|
||||
"impl": TextClassificationPipeline,
|
||||
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
|
||||
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": "distilbert-base-uncased-finetuned-sst-2-english",
|
||||
@@ -127,8 +127,8 @@ SUPPORTED_TASKS = {
|
||||
},
|
||||
"token-classification": {
|
||||
"impl": TokenClassificationPipeline,
|
||||
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
|
||||
"pt": AutoModelForTokenClassification if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
@@ -138,16 +138,16 @@ SUPPORTED_TASKS = {
|
||||
},
|
||||
"question-answering": {
|
||||
"impl": QuestionAnsweringPipeline,
|
||||
"tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
|
||||
"pt": AutoModelForQuestionAnswering if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
||||
},
|
||||
},
|
||||
"table-question-answering": {
|
||||
"impl": TableQuestionAnsweringPipeline,
|
||||
"pt": AutoModelForTableQuestionAnswering if is_torch_available() else None,
|
||||
"tf": None,
|
||||
"pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
|
||||
"tf": (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": "google/tapas-base-finetuned-wtq",
|
||||
@@ -158,21 +158,21 @@ SUPPORTED_TASKS = {
|
||||
},
|
||||
"fill-mask": {
|
||||
"impl": FillMaskPipeline,
|
||||
"tf": TFAutoModelForMaskedLM if is_tf_available() else None,
|
||||
"pt": AutoModelForMaskedLM if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
|
||||
},
|
||||
"summarization": {
|
||||
"impl": SummarizationPipeline,
|
||||
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
|
||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
|
||||
},
|
||||
# This task is a special case as it's parametrized by SRC, TGT languages.
|
||||
"translation": {
|
||||
"impl": TranslationPipeline,
|
||||
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
|
||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
||||
"default": {
|
||||
("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
@@ -181,20 +181,20 @@ SUPPORTED_TASKS = {
|
||||
},
|
||||
"text2text-generation": {
|
||||
"impl": Text2TextGenerationPipeline,
|
||||
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
|
||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
},
|
||||
"text-generation": {
|
||||
"impl": TextGenerationPipeline,
|
||||
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
|
||||
"pt": AutoModelForCausalLM if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForCausalLM,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
|
||||
},
|
||||
"zero-shot-classification": {
|
||||
"impl": ZeroShotClassificationPipeline,
|
||||
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
|
||||
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
||||
"config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
||||
@@ -203,14 +203,14 @@ SUPPORTED_TASKS = {
|
||||
},
|
||||
"conversational": {
|
||||
"impl": ConversationalPipeline,
|
||||
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
|
||||
"pt": AutoModelForCausalLM if is_torch_available() else None,
|
||||
"tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
|
||||
"pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
|
||||
},
|
||||
"image-classification": {
|
||||
"impl": ImageClassificationPipeline,
|
||||
"tf": None,
|
||||
"pt": AutoModelForImageClassification if is_torch_available() else None,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
||||
},
|
||||
}
|
||||
@@ -379,52 +379,34 @@ def pipeline(
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
||||
>>> pipeline('ner', model=model, tokenizer=tokenizer)
|
||||
"""
|
||||
|
||||
# Retrieve the task
|
||||
targeted_task, task_options = check_task(task)
|
||||
task_class = targeted_task["impl"]
|
||||
|
||||
# Use default model/config/tokenizer for the task if no model is provided
|
||||
if model is None:
|
||||
# At that point framework might still be undetermined
|
||||
model = get_default_model(targeted_task, framework, task_options)
|
||||
|
||||
# Config is the primordial information item.
|
||||
# Instantiate config if needed
|
||||
if isinstance(config, str):
|
||||
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
|
||||
elif config is None and isinstance(model, str):
|
||||
config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)
|
||||
|
||||
model_name = model if isinstance(model, str) else None
|
||||
|
||||
# Infer the framework form the model
|
||||
if framework is None:
|
||||
framework, model = infer_framework_from_model(model, targeted_task, revision=revision, task=task)
|
||||
|
||||
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
||||
|
||||
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
|
||||
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
|
||||
|
||||
# Instantiate config if needed
|
||||
if isinstance(config, str):
|
||||
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
|
||||
|
||||
# Instantiate model if needed
|
||||
if isinstance(model, str):
|
||||
# Handle transparent TF/PT model conversion
|
||||
if framework == "pt" and model.endswith(".h5"):
|
||||
model_kwargs["from_tf"] = True
|
||||
logger.warning(
|
||||
"Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
|
||||
"Trying to load the model with PyTorch."
|
||||
)
|
||||
elif framework == "tf" and model.endswith(".bin"):
|
||||
model_kwargs["from_pt"] = True
|
||||
logger.warning(
|
||||
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
|
||||
"Trying to load the model with Tensorflow."
|
||||
)
|
||||
|
||||
if model_class is None:
|
||||
raise ValueError(
|
||||
f"Pipeline using {framework} framework, but this framework is not supported by this pipeline."
|
||||
)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
model, config=config, revision=revision, _from_pipeline=task, **model_kwargs
|
||||
# Infer the framework from the model
|
||||
# Forced if framework already defined, inferred if it's None
|
||||
# Will load the correct model if possible
|
||||
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
|
||||
framework, model = infer_framework_load_model(
|
||||
model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task
|
||||
)
|
||||
|
||||
model_config = model.config
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import csv
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
@@ -21,11 +22,12 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from os.path import abspath, exists
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||
from ..modelcard import ModelCard
|
||||
from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
|
||||
from ..utils import logging
|
||||
|
||||
@@ -48,8 +50,108 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def infer_framework_load_model(
|
||||
model,
|
||||
config: AutoConfig,
|
||||
model_classes: Optional[Dict[str, Tuple[type]]] = None,
|
||||
task: Optional[str] = None,
|
||||
framework: Optional[str] = None,
|
||||
**model_kwargs
|
||||
):
|
||||
"""
|
||||
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
|
||||
|
||||
If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
|
||||
:obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
|
||||
Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.
|
||||
|
||||
If both frameworks are installed and available for :obj:`model`, PyTorch is selected.
|
||||
|
||||
Args:
|
||||
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
|
||||
The model to infer the framework from. If :obj:`str`, a checkpoint name. The model to infer the framewrok
|
||||
from.
|
||||
config (:class:`~transformers.AutoConfig`):
|
||||
The config associated with the model to help using the correct class
|
||||
model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
|
||||
A mapping framework to class.
|
||||
task (:obj:`str`):
|
||||
The task defining which pipeline will be returned.
|
||||
model_kwargs:
|
||||
Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
|
||||
**model_kwargs)` function.
|
||||
|
||||
Returns:
|
||||
:obj:`Tuple`: A tuple framework, model.
|
||||
"""
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
raise RuntimeError(
|
||||
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||
"To install PyTorch, read the instructions at https://pytorch.org/."
|
||||
)
|
||||
if isinstance(model, str):
|
||||
model_kwargs["_from_pipeline"] = task
|
||||
class_tuple = ()
|
||||
look_pt = is_torch_available() and framework in {"pt", None}
|
||||
look_tf = is_tf_available() and framework in {"tf", None}
|
||||
if model_classes:
|
||||
if look_pt:
|
||||
class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
|
||||
if look_tf:
|
||||
class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))
|
||||
if config.architectures:
|
||||
classes = []
|
||||
for architecture in config.architectures:
|
||||
transformers_module = importlib.import_module("transformers")
|
||||
if look_tf:
|
||||
_class = getattr(transformers_module, architecture, None)
|
||||
if _class is not None:
|
||||
classes.append(_class)
|
||||
if look_pt:
|
||||
_class = getattr(transformers_module, f"TF{architecture}", None)
|
||||
if _class is not None:
|
||||
classes.append(_class)
|
||||
class_tuple = class_tuple + tuple(classes)
|
||||
|
||||
if len(class_tuple) == 0:
|
||||
raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")
|
||||
|
||||
for model_class in class_tuple:
|
||||
kwargs = model_kwargs.copy()
|
||||
if framework == "pt" and model.endswith(".h5"):
|
||||
kwargs["from_tf"] = True
|
||||
logger.warning(
|
||||
"Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
|
||||
"Trying to load the model with PyTorch."
|
||||
)
|
||||
elif framework == "tf" and model.endswith(".bin"):
|
||||
kwargs["from_pt"] = True
|
||||
logger.warning(
|
||||
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
|
||||
"Trying to load the model with Tensorflow."
|
||||
)
|
||||
|
||||
try:
|
||||
model = model_class.from_pretrained(model, **kwargs)
|
||||
# Stop loading on the first successful load.
|
||||
break
|
||||
except (OSError, ValueError):
|
||||
continue
|
||||
|
||||
if isinstance(model, str):
|
||||
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
|
||||
|
||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||
return framework, model
|
||||
|
||||
|
||||
def infer_framework_from_model(
|
||||
model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs
|
||||
model,
|
||||
model_classes: Optional[Dict[str, Tuple[type]]] = None,
|
||||
task: Optional[str] = None,
|
||||
framework: Optional[str] = None,
|
||||
**model_kwargs
|
||||
):
|
||||
"""
|
||||
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
|
||||
@@ -75,30 +177,13 @@ def infer_framework_from_model(
|
||||
Returns:
|
||||
:obj:`Tuple`: A tuple framework, model.
|
||||
"""
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
raise RuntimeError(
|
||||
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||
"To install PyTorch, read the instructions at https://pytorch.org/."
|
||||
)
|
||||
if isinstance(model, str):
|
||||
model_kwargs["_from_pipeline"] = task
|
||||
if is_torch_available() and not is_tf_available():
|
||||
model_class = model_classes.get("pt", AutoModel)
|
||||
model = model_class.from_pretrained(model, **model_kwargs)
|
||||
elif is_tf_available() and not is_torch_available():
|
||||
model_class = model_classes.get("tf", TFAutoModel)
|
||||
model = model_class.from_pretrained(model, **model_kwargs)
|
||||
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs)
|
||||
else:
|
||||
try:
|
||||
model_class = model_classes.get("pt", AutoModel)
|
||||
model = model_class.from_pretrained(model, **model_kwargs)
|
||||
except OSError:
|
||||
model_class = model_classes.get("tf", TFAutoModel)
|
||||
model = model_class.from_pretrained(model, **model_kwargs)
|
||||
|
||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||
return framework, model
|
||||
config = model.config
|
||||
return infer_framework_load_model(
|
||||
model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs
|
||||
)
|
||||
|
||||
|
||||
def get_framework(model, revision: Optional[str] = None):
|
||||
@@ -534,7 +619,7 @@ class Pipeline(_ScikitCompat):
|
||||
):
|
||||
|
||||
if framework is None:
|
||||
framework, model = infer_framework_from_model(model)
|
||||
framework, model = infer_framework_load_model(model, config=model.config)
|
||||
|
||||
self.task = task
|
||||
self.model = model
|
||||
|
||||
@@ -18,6 +18,8 @@ from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BlenderbotSmallForConditionalGeneration,
|
||||
BlenderbotSmallTokenizer,
|
||||
Conversation,
|
||||
ConversationalPipeline,
|
||||
is_torch_available,
|
||||
@@ -389,3 +391,32 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
|
||||
self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
|
||||
self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_from_pipeline_conversation(self):
|
||||
model_id = "facebook/blenderbot_small-90M"
|
||||
|
||||
# from model id
|
||||
conversation_agent_from_model_id = pipeline("conversational", model=model_id, tokenizer=model_id)
|
||||
|
||||
# from model object
|
||||
model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_id)
|
||||
tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_id)
|
||||
conversation_agent_from_model = pipeline("conversational", model=model, tokenizer=tokenizer)
|
||||
|
||||
conversation = Conversation("My name is Sarah and I live in London")
|
||||
conversation_copy = Conversation("My name is Sarah and I live in London")
|
||||
|
||||
result_model_id = conversation_agent_from_model_id([conversation])
|
||||
result_model = conversation_agent_from_model([conversation_copy])
|
||||
|
||||
# check for equality
|
||||
self.assertEqual(
|
||||
result_model_id.generated_responses[0],
|
||||
"hi sarah, i live in london as well. do you have any plans for the weekend?",
|
||||
)
|
||||
self.assertEqual(
|
||||
result_model_id.generated_responses[0],
|
||||
result_model.generated_responses[0],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user