From 2056f26e853574034e426d97e4f803b47f8c7159 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 7 Jun 2021 08:41:27 -0700 Subject: [PATCH] Extend pipelines for automodel tupels (#12025) * fix_torch_device_generate_test * remove @ * finish * refactor * add test * fix test * Attempt at simplification. * Small fix. * Fixing non existing AutoModel for TF. * Naming. * Remove extra condition. Co-authored-by: patrickvonplaten --- src/transformers/pipelines/__init__.py | 104 ++++++++----------- src/transformers/pipelines/base.py | 137 ++++++++++++++++++++----- tests/test_pipelines_conversational.py | 31 ++++++ 3 files changed, 185 insertions(+), 87 deletions(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 33f3fe12e1..ea353caa52 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -37,7 +37,7 @@ from .base import ( PipelineDataFormat, PipelineException, get_default_model, - infer_framework_from_model, + infer_framework_load_model, ) from .conversational import Conversation, ConversationalPipeline from .feature_extraction import FeatureExtractionPipeline @@ -110,14 +110,14 @@ TASK_ALIASES = { SUPPORTED_TASKS = { "feature-extraction": { "impl": FeatureExtractionPipeline, - "tf": TFAutoModel if is_tf_available() else None, - "pt": AutoModel if is_torch_available() else None, + "tf": (TFAutoModel,) if is_tf_available() else (), + "pt": (AutoModel,) if is_torch_available() else (), "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}}, }, "text-classification": { "impl": TextClassificationPipeline, - "tf": TFAutoModelForSequenceClassification if is_tf_available() else None, - "pt": AutoModelForSequenceClassification if is_torch_available() else None, + "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (), + "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (), "default": { "model": { "pt": 
"distilbert-base-uncased-finetuned-sst-2-english", @@ -127,8 +127,8 @@ SUPPORTED_TASKS = { }, "token-classification": { "impl": TokenClassificationPipeline, - "tf": TFAutoModelForTokenClassification if is_tf_available() else None, - "pt": AutoModelForTokenClassification if is_torch_available() else None, + "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (), + "pt": (AutoModelForTokenClassification,) if is_torch_available() else (), "default": { "model": { "pt": "dbmdz/bert-large-cased-finetuned-conll03-english", @@ -138,16 +138,16 @@ SUPPORTED_TASKS = { }, "question-answering": { "impl": QuestionAnsweringPipeline, - "tf": TFAutoModelForQuestionAnswering if is_tf_available() else None, - "pt": AutoModelForQuestionAnswering if is_torch_available() else None, + "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (), + "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (), "default": { "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"}, }, }, "table-question-answering": { "impl": TableQuestionAnsweringPipeline, - "pt": AutoModelForTableQuestionAnswering if is_torch_available() else None, - "tf": None, + "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (), + "tf": (), "default": { "model": { "pt": "google/tapas-base-finetuned-wtq", @@ -158,21 +158,21 @@ SUPPORTED_TASKS = { }, "fill-mask": { "impl": FillMaskPipeline, - "tf": TFAutoModelForMaskedLM if is_tf_available() else None, - "pt": AutoModelForMaskedLM if is_torch_available() else None, + "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (), + "pt": (AutoModelForMaskedLM,) if is_torch_available() else (), "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}}, }, "summarization": { "impl": SummarizationPipeline, - "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None, - "pt": AutoModelForSeq2SeqLM if is_torch_available() else None, + "tf": 
(TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), + "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}}, }, # This task is a special case as it's parametrized by SRC, TGT languages. "translation": { "impl": TranslationPipeline, - "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None, - "pt": AutoModelForSeq2SeqLM if is_torch_available() else None, + "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), + "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), "default": { ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}}, ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}}, @@ -181,20 +181,20 @@ SUPPORTED_TASKS = { }, "text2text-generation": { "impl": Text2TextGenerationPipeline, - "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None, - "pt": AutoModelForSeq2SeqLM if is_torch_available() else None, + "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), + "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), "default": {"model": {"pt": "t5-base", "tf": "t5-base"}}, }, "text-generation": { "impl": TextGenerationPipeline, - "tf": TFAutoModelForCausalLM if is_tf_available() else None, - "pt": AutoModelForCausalLM if is_torch_available() else None, + "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (), + "pt": (AutoModelForCausalLM,) if is_torch_available() else (), "default": {"model": {"pt": "gpt2", "tf": "gpt2"}}, }, "zero-shot-classification": { "impl": ZeroShotClassificationPipeline, - "tf": TFAutoModelForSequenceClassification if is_tf_available() else None, - "pt": AutoModelForSequenceClassification if is_torch_available() else None, + "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (), + "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (), "default": { "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"}, "config": {"pt": 
"facebook/bart-large-mnli", "tf": "roberta-large-mnli"}, @@ -203,14 +203,14 @@ SUPPORTED_TASKS = { }, "conversational": { "impl": ConversationalPipeline, - "tf": TFAutoModelForCausalLM if is_tf_available() else None, - "pt": AutoModelForCausalLM if is_torch_available() else None, + "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (), + "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (), "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}}, }, "image-classification": { "impl": ImageClassificationPipeline, - "tf": None, - "pt": AutoModelForImageClassification if is_torch_available() else None, + "tf": (), + "pt": (AutoModelForImageClassification,) if is_torch_available() else (), "default": {"model": {"pt": "google/vit-base-patch16-224"}}, }, } @@ -379,53 +379,35 @@ def pipeline( >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") >>> pipeline('ner', model=model, tokenizer=tokenizer) """ + # Retrieve the task targeted_task, task_options = check_task(task) + task_class = targeted_task["impl"] # Use default model/config/tokenizer for the task if no model is provided if model is None: # At that point framework might still be undetermined model = get_default_model(targeted_task, framework, task_options) + # Config is the primordial information item. 
+ # Instantiate config if needed + if isinstance(config, str): + config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs) + elif config is None and isinstance(model, str): + config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs) + model_name = model if isinstance(model, str) else None - # Infer the framework form the model - if framework is None: - framework, model = infer_framework_from_model(model, targeted_task, revision=revision, task=task) - - task_class, model_class = targeted_task["impl"], targeted_task[framework] - # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token) - # Instantiate config if needed - if isinstance(config, str): - config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs) - - # Instantiate model if needed - if isinstance(model, str): - # Handle transparent TF/PT model conversion - if framework == "pt" and model.endswith(".h5"): - model_kwargs["from_tf"] = True - logger.warning( - "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. " - "Trying to load the model with PyTorch." - ) - elif framework == "tf" and model.endswith(".bin"): - model_kwargs["from_pt"] = True - logger.warning( - "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. " - "Trying to load the model with Tensorflow." - ) - - if model_class is None: - raise ValueError( - f"Pipeline using {framework} framework, but this framework is not supported by this pipeline." 
- ) - - model = model_class.from_pretrained( - model, config=config, revision=revision, _from_pipeline=task, **model_kwargs - ) + # Infer the framework from the model + # Forced if framework already defined, inferred if it's None + # Will load the correct model if possible + model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]} + framework, model = infer_framework_load_model( + model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task + ) model_config = model.config diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 05bf389b8a..5065c56ca2 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import csv +import importlib import json import os import pickle @@ -21,11 +22,12 @@ import warnings from abc import ABC, abstractmethod from contextlib import contextmanager from os.path import abspath, exists -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from ..feature_extraction_utils import PreTrainedFeatureExtractor from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available from ..modelcard import ModelCard +from ..models.auto.configuration_auto import AutoConfig from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy from ..utils import logging @@ -48,8 +50,108 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) +def infer_framework_load_model( + model, + config: AutoConfig, + model_classes: Optional[Dict[str, Tuple[type]]] = None, + task: Optional[str] = None, + framework: Optional[str] = None, + **model_kwargs +): + """ + Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model). 
+ + If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise + :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`. + Since we don't want to instantiate the model twice, this model is returned for use by the pipeline. + + If both frameworks are installed and available for :obj:`model`, PyTorch is selected. + + Args: + model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`): + The model to infer the framework from. If :obj:`str`, a checkpoint name. The model to infer the framework + from. + config (:class:`~transformers.AutoConfig`): + The config associated with the model to help using the correct class + model_classes (dictionary :obj:`str` to :obj:`type`, `optional`): + A mapping framework to class. + task (:obj:`str`): + The task defining which pipeline will be returned. + model_kwargs: + Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(..., + **model_kwargs)` function. + + Returns: + :obj:`Tuple`: A tuple framework, model. + """ + if not is_tf_available() and not is_torch_available(): + raise RuntimeError( + "At least one of TensorFlow 2.0 or PyTorch should be installed. " + "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " + "To install PyTorch, read the instructions at https://pytorch.org/." 
+ ) + if isinstance(model, str): + model_kwargs["_from_pipeline"] = task + class_tuple = () + look_pt = is_torch_available() and framework in {"pt", None} + look_tf = is_tf_available() and framework in {"tf", None} + if model_classes: + if look_pt: + class_tuple = class_tuple + model_classes.get("pt", (AutoModel,)) + if look_tf: + class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,)) + if config.architectures: + classes = [] + for architecture in config.architectures: + transformers_module = importlib.import_module("transformers") + if look_pt: + _class = getattr(transformers_module, architecture, None) + if _class is not None: + classes.append(_class) + if look_tf: + _class = getattr(transformers_module, f"TF{architecture}", None) + if _class is not None: + classes.append(_class) + class_tuple = class_tuple + tuple(classes) + + if len(class_tuple) == 0: + raise ValueError(f"Pipeline cannot infer suitable model classes from {model}") + + for model_class in class_tuple: + kwargs = model_kwargs.copy() + if framework == "pt" and model.endswith(".h5"): + kwargs["from_tf"] = True + logger.warning( + "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. " + "Trying to load the model with PyTorch." + ) + elif framework == "tf" and model.endswith(".bin"): + kwargs["from_pt"] = True + logger.warning( + "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. " + "Trying to load the model with Tensorflow." + ) + + try: + model = model_class.from_pretrained(model, **kwargs) + # Stop loading on the first successful load. 
+ break + except (OSError, ValueError): + continue + + if isinstance(model, str): + raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.") + + framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" + return framework, model + + def infer_framework_from_model( - model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs + model, + model_classes: Optional[Dict[str, Tuple[type]]] = None, + task: Optional[str] = None, + framework: Optional[str] = None, + **model_kwargs ): """ Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model). @@ -75,30 +177,13 @@ def infer_framework_from_model( Returns: :obj:`Tuple`: A tuple framework, model. """ - if not is_tf_available() and not is_torch_available(): - raise RuntimeError( - "At least one of TensorFlow 2.0 or PyTorch should be installed. " - "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " - "To install PyTorch, read the instructions at https://pytorch.org/." 
- ) if isinstance(model, str): - model_kwargs["_from_pipeline"] = task - if is_torch_available() and not is_tf_available(): - model_class = model_classes.get("pt", AutoModel) - model = model_class.from_pretrained(model, **model_kwargs) - elif is_tf_available() and not is_torch_available(): - model_class = model_classes.get("tf", TFAutoModel) - model = model_class.from_pretrained(model, **model_kwargs) - else: - try: - model_class = model_classes.get("pt", AutoModel) - model = model_class.from_pretrained(model, **model_kwargs) - except OSError: - model_class = model_classes.get("tf", TFAutoModel) - model = model_class.from_pretrained(model, **model_kwargs) - - framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" - return framework, model + config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs) + else: + config = model.config + return infer_framework_load_model( + model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs + ) def get_framework(model, revision: Optional[str] = None): @@ -534,7 +619,7 @@ class Pipeline(_ScikitCompat): ): if framework is None: - framework, model = infer_framework_from_model(model) + framework, model = infer_framework_load_model(model, config=model.config) self.task = task self.model = model diff --git a/tests/test_pipelines_conversational.py b/tests/test_pipelines_conversational.py index 0500f61726..89524dd3fb 100644 --- a/tests/test_pipelines_conversational.py +++ b/tests/test_pipelines_conversational.py @@ -18,6 +18,8 @@ from transformers import ( AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, + BlenderbotSmallForConditionalGeneration, + BlenderbotSmallTokenizer, Conversation, ConversationalPipeline, is_torch_available, @@ -389,3 +391,32 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. 
i'm not sure what to do yet.") self.assertEqual(result[1].past_user_inputs[1], "What's your name?") self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.") + + @require_torch + @slow + def test_from_pipeline_conversation(self): + model_id = "facebook/blenderbot_small-90M" + + # from model id + conversation_agent_from_model_id = pipeline("conversational", model=model_id, tokenizer=model_id) + + # from model object + model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_id) + tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_id) + conversation_agent_from_model = pipeline("conversational", model=model, tokenizer=tokenizer) + + conversation = Conversation("My name is Sarah and I live in London") + conversation_copy = Conversation("My name is Sarah and I live in London") + + result_model_id = conversation_agent_from_model_id([conversation]) + result_model = conversation_agent_from_model([conversation_copy]) + + # check for equality + self.assertEqual( + result_model_id.generated_responses[0], + "hi sarah, i live in london as well. do you have any plans for the weekend?", + ) + self.assertEqual( + result_model_id.generated_responses[0], + result_model.generated_responses[0], + )