Extend pipelines for automodel tuples (#12025)

* fix_torch_device_generate_test

* remove @

* finish

* refactor

* add test

* fix test

* Attempt at simplification.

* Small fix.

* Fixing non existing AutoModel for TF.

* Naming.

* Remove extra condition.

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Nicolas Patry
2021-06-07 08:41:27 -07:00
committed by GitHub
parent f8bd8c6c7e
commit 2056f26e85
3 changed files with 185 additions and 87 deletions

View File

@@ -37,7 +37,7 @@ from .base import (
PipelineDataFormat, PipelineDataFormat,
PipelineException, PipelineException,
get_default_model, get_default_model,
infer_framework_from_model, infer_framework_load_model,
) )
from .conversational import Conversation, ConversationalPipeline from .conversational import Conversation, ConversationalPipeline
from .feature_extraction import FeatureExtractionPipeline from .feature_extraction import FeatureExtractionPipeline
@@ -110,14 +110,14 @@ TASK_ALIASES = {
SUPPORTED_TASKS = { SUPPORTED_TASKS = {
"feature-extraction": { "feature-extraction": {
"impl": FeatureExtractionPipeline, "impl": FeatureExtractionPipeline,
"tf": TFAutoModel if is_tf_available() else None, "tf": (TFAutoModel,) if is_tf_available() else (),
"pt": AutoModel if is_torch_available() else None, "pt": (AutoModel,) if is_torch_available() else (),
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}}, "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
}, },
"text-classification": { "text-classification": {
"impl": TextClassificationPipeline, "impl": TextClassificationPipeline,
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None, "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
"pt": AutoModelForSequenceClassification if is_torch_available() else None, "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
"default": { "default": {
"model": { "model": {
"pt": "distilbert-base-uncased-finetuned-sst-2-english", "pt": "distilbert-base-uncased-finetuned-sst-2-english",
@@ -127,8 +127,8 @@ SUPPORTED_TASKS = {
}, },
"token-classification": { "token-classification": {
"impl": TokenClassificationPipeline, "impl": TokenClassificationPipeline,
"tf": TFAutoModelForTokenClassification if is_tf_available() else None, "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
"pt": AutoModelForTokenClassification if is_torch_available() else None, "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
"default": { "default": {
"model": { "model": {
"pt": "dbmdz/bert-large-cased-finetuned-conll03-english", "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
@@ -138,16 +138,16 @@ SUPPORTED_TASKS = {
}, },
"question-answering": { "question-answering": {
"impl": QuestionAnsweringPipeline, "impl": QuestionAnsweringPipeline,
"tf": TFAutoModelForQuestionAnswering if is_tf_available() else None, "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
"pt": AutoModelForQuestionAnswering if is_torch_available() else None, "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
"default": { "default": {
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"}, "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
}, },
}, },
"table-question-answering": { "table-question-answering": {
"impl": TableQuestionAnsweringPipeline, "impl": TableQuestionAnsweringPipeline,
"pt": AutoModelForTableQuestionAnswering if is_torch_available() else None, "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
"tf": None, "tf": (),
"default": { "default": {
"model": { "model": {
"pt": "google/tapas-base-finetuned-wtq", "pt": "google/tapas-base-finetuned-wtq",
@@ -158,21 +158,21 @@ SUPPORTED_TASKS = {
}, },
"fill-mask": { "fill-mask": {
"impl": FillMaskPipeline, "impl": FillMaskPipeline,
"tf": TFAutoModelForMaskedLM if is_tf_available() else None, "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
"pt": AutoModelForMaskedLM if is_torch_available() else None, "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}}, "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
}, },
"summarization": { "summarization": {
"impl": SummarizationPipeline, "impl": SummarizationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None, "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None, "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}}, "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
}, },
# This task is a special case as it's parametrized by SRC, TGT languages. # This task is a special case as it's parametrized by SRC, TGT languages.
"translation": { "translation": {
"impl": TranslationPipeline, "impl": TranslationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None, "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None, "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
"default": { "default": {
("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}}, ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}}, ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
@@ -181,20 +181,20 @@ SUPPORTED_TASKS = {
}, },
"text2text-generation": { "text2text-generation": {
"impl": Text2TextGenerationPipeline, "impl": Text2TextGenerationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None, "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None, "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}}, "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
}, },
"text-generation": { "text-generation": {
"impl": TextGenerationPipeline, "impl": TextGenerationPipeline,
"tf": TFAutoModelForCausalLM if is_tf_available() else None, "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
"pt": AutoModelForCausalLM if is_torch_available() else None, "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}}, "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
}, },
"zero-shot-classification": { "zero-shot-classification": {
"impl": ZeroShotClassificationPipeline, "impl": ZeroShotClassificationPipeline,
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None, "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
"pt": AutoModelForSequenceClassification if is_torch_available() else None, "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
"default": { "default": {
"model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"}, "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
"config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"}, "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
@@ -203,14 +203,14 @@ SUPPORTED_TASKS = {
}, },
"conversational": { "conversational": {
"impl": ConversationalPipeline, "impl": ConversationalPipeline,
"tf": TFAutoModelForCausalLM if is_tf_available() else None, "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
"pt": AutoModelForCausalLM if is_torch_available() else None, "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}}, "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
}, },
"image-classification": { "image-classification": {
"impl": ImageClassificationPipeline, "impl": ImageClassificationPipeline,
"tf": None, "tf": (),
"pt": AutoModelForImageClassification if is_torch_available() else None, "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
"default": {"model": {"pt": "google/vit-base-patch16-224"}}, "default": {"model": {"pt": "google/vit-base-patch16-224"}},
}, },
} }
@@ -379,53 +379,35 @@ def pipeline(
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
>>> pipeline('ner', model=model, tokenizer=tokenizer) >>> pipeline('ner', model=model, tokenizer=tokenizer)
""" """
# Retrieve the task # Retrieve the task
targeted_task, task_options = check_task(task) targeted_task, task_options = check_task(task)
task_class = targeted_task["impl"]
# Use default model/config/tokenizer for the task if no model is provided # Use default model/config/tokenizer for the task if no model is provided
if model is None: if model is None:
# At that point framework might still be undetermined # At that point framework might still be undetermined
model = get_default_model(targeted_task, framework, task_options) model = get_default_model(targeted_task, framework, task_options)
# Config is the primordial information item.
# Instantiate config if needed
if isinstance(config, str):
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
elif config is None and isinstance(model, str):
config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)
model_name = model if isinstance(model, str) else None model_name = model if isinstance(model, str) else None
# Infer the framework form the model
if framework is None:
framework, model = infer_framework_from_model(model, targeted_task, revision=revision, task=task)
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token) model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
# Instantiate config if needed # Infer the framework from the model
if isinstance(config, str): # Forced if framework already defined, inferred if it's None
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs) # Will load the correct model if possible
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
# Instantiate model if needed framework, model = infer_framework_load_model(
if isinstance(model, str): model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task
# Handle transparent TF/PT model conversion )
if framework == "pt" and model.endswith(".h5"):
model_kwargs["from_tf"] = True
logger.warning(
"Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
"Trying to load the model with PyTorch."
)
elif framework == "tf" and model.endswith(".bin"):
model_kwargs["from_pt"] = True
logger.warning(
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
"Trying to load the model with Tensorflow."
)
if model_class is None:
raise ValueError(
f"Pipeline using {framework} framework, but this framework is not supported by this pipeline."
)
model = model_class.from_pretrained(
model, config=config, revision=revision, _from_pipeline=task, **model_kwargs
)
model_config = model.config model_config = model.config

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import csv import csv
import importlib
import json import json
import os import os
import pickle import pickle
@@ -21,11 +22,12 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from os.path import abspath, exists from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from ..feature_extraction_utils import PreTrainedFeatureExtractor from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
from ..utils import logging from ..utils import logging
@@ -48,8 +50,108 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def infer_framework_load_model(
    model,
    config: AutoConfig,
    model_classes: Optional[Dict[str, Tuple[type]]] = None,
    task: Optional[str] = None,
    framework: Optional[str] = None,
    **model_kwargs
):
    """
    Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).

    If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
    :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
    Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.

    Candidate classes are tried in order (PyTorch classes from :obj:`model_classes` first, then TensorFlow ones, then
    the classes named in :obj:`config.architectures`); the first class that loads successfully wins.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            The model to infer the framework from. If :obj:`str`, a checkpoint name.
        config (:class:`~transformers.AutoConfig`):
            The config associated with the model to help using the correct class
        model_classes (dictionary :obj:`str` to tuple of :obj:`type`, `optional`):
            A mapping from framework (:obj:`"pt"` or :obj:`"tf"`) to the candidate model classes.
        task (:obj:`str`):
            The task defining which pipeline will be returned.
        framework (:obj:`str`, `optional`):
            :obj:`"pt"` or :obj:`"tf"` to restrict the candidate classes to one framework, or :obj:`None` to consider
            every installed framework.
        model_kwargs:
            Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
            **model_kwargs)` function.

    Returns:
        :obj:`Tuple`: A tuple framework, model.

    Raises:
        RuntimeError: If neither TensorFlow nor PyTorch is installed.
        ValueError: If no candidate class can be determined, or none of them manages to load :obj:`model`.
    """
    if not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    if isinstance(model, str):
        model_kwargs["_from_pipeline"] = task
        class_tuple = ()
        look_pt = is_torch_available() and framework in {"pt", None}
        look_tf = is_tf_available() and framework in {"tf", None}
        if model_classes:
            if look_pt:
                class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
            if look_tf:
                class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))

        # `config` may be missing; only consult `architectures` when it is actually there.
        architectures = getattr(config, "architectures", None)
        if architectures:
            # The module lookup does not depend on the architecture, so resolve it once.
            transformers_module = importlib.import_module("transformers")
            classes = []
            for architecture in architectures:
                # PyTorch classes carry the bare architecture name, TensorFlow ones the `TF` prefix.
                if look_pt:
                    _class = getattr(transformers_module, architecture, None)
                    if _class is not None:
                        classes.append(_class)
                if look_tf:
                    _class = getattr(transformers_module, f"TF{architecture}", None)
                    if _class is not None:
                        classes.append(_class)
            class_tuple = class_tuple + tuple(classes)

        if len(class_tuple) == 0:
            raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")

        # The cross-framework conversion flags only depend on `framework` and the checkpoint name, so they are
        # computed (and the warning emitted) once rather than once per candidate class.
        load_kwargs = dict(model_kwargs)
        if framework == "pt" and model.endswith(".h5"):
            load_kwargs["from_tf"] = True
            logger.warning(
                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                "Trying to load the model with PyTorch."
            )
        elif framework == "tf" and model.endswith(".bin"):
            load_kwargs["from_pt"] = True
            logger.warning(
                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                "Trying to load the model with Tensorflow."
            )

        for model_class in class_tuple:
            try:
                # `**load_kwargs` hands each call its own dict, so a failed attempt cannot pollute the next one.
                model = model_class.from_pretrained(model, **load_kwargs)
            except (OSError, ValueError):
                continue
            # Stop loading on the first successful load.
            break

        if isinstance(model, str):
            raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")

    framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    return framework, model
def infer_framework_from_model( def infer_framework_from_model(
model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs model,
model_classes: Optional[Dict[str, Tuple[type]]] = None,
task: Optional[str] = None,
framework: Optional[str] = None,
**model_kwargs
): ):
""" """
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model). Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
@@ -75,30 +177,13 @@ def infer_framework_from_model(
Returns: Returns:
:obj:`Tuple`: A tuple framework, model. :obj:`Tuple`: A tuple framework, model.
""" """
if not is_tf_available() and not is_torch_available():
raise RuntimeError(
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
"To install PyTorch, read the instructions at https://pytorch.org/."
)
if isinstance(model, str): if isinstance(model, str):
model_kwargs["_from_pipeline"] = task config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs)
if is_torch_available() and not is_tf_available(): else:
model_class = model_classes.get("pt", AutoModel) config = model.config
model = model_class.from_pretrained(model, **model_kwargs) return infer_framework_load_model(
elif is_tf_available() and not is_torch_available(): model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs
model_class = model_classes.get("tf", TFAutoModel) )
model = model_class.from_pretrained(model, **model_kwargs)
else:
try:
model_class = model_classes.get("pt", AutoModel)
model = model_class.from_pretrained(model, **model_kwargs)
except OSError:
model_class = model_classes.get("tf", TFAutoModel)
model = model_class.from_pretrained(model, **model_kwargs)
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
return framework, model
def get_framework(model, revision: Optional[str] = None): def get_framework(model, revision: Optional[str] = None):
@@ -534,7 +619,7 @@ class Pipeline(_ScikitCompat):
): ):
if framework is None: if framework is None:
framework, model = infer_framework_from_model(model) framework, model = infer_framework_load_model(model, config=model.config)
self.task = task self.task = task
self.model = model self.model = model

View File

@@ -18,6 +18,8 @@ from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
AutoTokenizer, AutoTokenizer,
BlenderbotSmallForConditionalGeneration,
BlenderbotSmallTokenizer,
Conversation, Conversation,
ConversationalPipeline, ConversationalPipeline,
is_torch_available, is_torch_available,
@@ -389,3 +391,32 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.") self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
self.assertEqual(result[1].past_user_inputs[1], "What's your name?") self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.") self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")
@require_torch
@slow
def test_from_pipeline_conversation(self):
    """A conversational pipeline built from a checkpoint name must answer like one built from loaded objects."""
    checkpoint = "facebook/blenderbot_small-90M"

    # Pipeline resolved entirely from the checkpoint name.
    agent_from_name = pipeline("conversational", model=checkpoint, tokenizer=checkpoint)

    # Pipeline assembled from an already-instantiated model and tokenizer.
    loaded_model = BlenderbotSmallForConditionalGeneration.from_pretrained(checkpoint)
    loaded_tokenizer = BlenderbotSmallTokenizer.from_pretrained(checkpoint)
    agent_from_objects = pipeline("conversational", model=loaded_model, tokenizer=loaded_tokenizer)

    # Each pipeline gets its own conversation object carrying the same opening message.
    first_conversation = Conversation("My name is Sarah and I live in London")
    second_conversation = Conversation("My name is Sarah and I live in London")

    outcome_from_name = agent_from_name([first_conversation])
    outcome_from_objects = agent_from_objects([second_conversation])

    # The checkpoint-name pipeline yields the expected reply ...
    self.assertEqual(
        outcome_from_name.generated_responses[0],
        "hi sarah, i live in london as well. do you have any plans for the weekend?",
    )
    # ... and both construction paths agree with each other.
    self.assertEqual(
        outcome_from_name.generated_responses[0],
        outcome_from_objects.generated_responses[0],
    )