Extend pipelines for automodel tupels (#12025)
* fix_torch_device_generate_test * remove @ * finish * refactor * add test * fix test * Attempt at simplification. * Small fix. * Fixing non existing AutoModel for TF. * Naming. * Remove extra condition. Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -37,7 +37,7 @@ from .base import (
|
|||||||
PipelineDataFormat,
|
PipelineDataFormat,
|
||||||
PipelineException,
|
PipelineException,
|
||||||
get_default_model,
|
get_default_model,
|
||||||
infer_framework_from_model,
|
infer_framework_load_model,
|
||||||
)
|
)
|
||||||
from .conversational import Conversation, ConversationalPipeline
|
from .conversational import Conversation, ConversationalPipeline
|
||||||
from .feature_extraction import FeatureExtractionPipeline
|
from .feature_extraction import FeatureExtractionPipeline
|
||||||
@@ -110,14 +110,14 @@ TASK_ALIASES = {
|
|||||||
SUPPORTED_TASKS = {
|
SUPPORTED_TASKS = {
|
||||||
"feature-extraction": {
|
"feature-extraction": {
|
||||||
"impl": FeatureExtractionPipeline,
|
"impl": FeatureExtractionPipeline,
|
||||||
"tf": TFAutoModel if is_tf_available() else None,
|
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||||
"pt": AutoModel if is_torch_available() else None,
|
"pt": (AutoModel,) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
|
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
|
||||||
},
|
},
|
||||||
"text-classification": {
|
"text-classification": {
|
||||||
"impl": TextClassificationPipeline,
|
"impl": TextClassificationPipeline,
|
||||||
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
|
"tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
|
||||||
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
|
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
|
||||||
"default": {
|
"default": {
|
||||||
"model": {
|
"model": {
|
||||||
"pt": "distilbert-base-uncased-finetuned-sst-2-english",
|
"pt": "distilbert-base-uncased-finetuned-sst-2-english",
|
||||||
@@ -127,8 +127,8 @@ SUPPORTED_TASKS = {
|
|||||||
},
|
},
|
||||||
"token-classification": {
|
"token-classification": {
|
||||||
"impl": TokenClassificationPipeline,
|
"impl": TokenClassificationPipeline,
|
||||||
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
|
"tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
|
||||||
"pt": AutoModelForTokenClassification if is_torch_available() else None,
|
"pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
|
||||||
"default": {
|
"default": {
|
||||||
"model": {
|
"model": {
|
||||||
"pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
"pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||||
@@ -138,16 +138,16 @@ SUPPORTED_TASKS = {
|
|||||||
},
|
},
|
||||||
"question-answering": {
|
"question-answering": {
|
||||||
"impl": QuestionAnsweringPipeline,
|
"impl": QuestionAnsweringPipeline,
|
||||||
"tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
|
"tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
|
||||||
"pt": AutoModelForQuestionAnswering if is_torch_available() else None,
|
"pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
|
||||||
"default": {
|
"default": {
|
||||||
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"table-question-answering": {
|
"table-question-answering": {
|
||||||
"impl": TableQuestionAnsweringPipeline,
|
"impl": TableQuestionAnsweringPipeline,
|
||||||
"pt": AutoModelForTableQuestionAnswering if is_torch_available() else None,
|
"pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
|
||||||
"tf": None,
|
"tf": (),
|
||||||
"default": {
|
"default": {
|
||||||
"model": {
|
"model": {
|
||||||
"pt": "google/tapas-base-finetuned-wtq",
|
"pt": "google/tapas-base-finetuned-wtq",
|
||||||
@@ -158,21 +158,21 @@ SUPPORTED_TASKS = {
|
|||||||
},
|
},
|
||||||
"fill-mask": {
|
"fill-mask": {
|
||||||
"impl": FillMaskPipeline,
|
"impl": FillMaskPipeline,
|
||||||
"tf": TFAutoModelForMaskedLM if is_tf_available() else None,
|
"tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
|
||||||
"pt": AutoModelForMaskedLM if is_torch_available() else None,
|
"pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
|
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
|
||||||
},
|
},
|
||||||
"summarization": {
|
"summarization": {
|
||||||
"impl": SummarizationPipeline,
|
"impl": SummarizationPipeline,
|
||||||
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
|
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
||||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
|
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
|
||||||
},
|
},
|
||||||
# This task is a special case as it's parametrized by SRC, TGT languages.
|
# This task is a special case as it's parametrized by SRC, TGT languages.
|
||||||
"translation": {
|
"translation": {
|
||||||
"impl": TranslationPipeline,
|
"impl": TranslationPipeline,
|
||||||
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
|
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
||||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
||||||
"default": {
|
"default": {
|
||||||
("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||||
("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||||
@@ -181,20 +181,20 @@ SUPPORTED_TASKS = {
|
|||||||
},
|
},
|
||||||
"text2text-generation": {
|
"text2text-generation": {
|
||||||
"impl": Text2TextGenerationPipeline,
|
"impl": Text2TextGenerationPipeline,
|
||||||
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
|
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
||||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||||
},
|
},
|
||||||
"text-generation": {
|
"text-generation": {
|
||||||
"impl": TextGenerationPipeline,
|
"impl": TextGenerationPipeline,
|
||||||
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
|
"tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
|
||||||
"pt": AutoModelForCausalLM if is_torch_available() else None,
|
"pt": (AutoModelForCausalLM,) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
|
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
|
||||||
},
|
},
|
||||||
"zero-shot-classification": {
|
"zero-shot-classification": {
|
||||||
"impl": ZeroShotClassificationPipeline,
|
"impl": ZeroShotClassificationPipeline,
|
||||||
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
|
"tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
|
||||||
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
|
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
|
||||||
"default": {
|
"default": {
|
||||||
"model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
"model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
||||||
"config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
"config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
||||||
@@ -203,14 +203,14 @@ SUPPORTED_TASKS = {
|
|||||||
},
|
},
|
||||||
"conversational": {
|
"conversational": {
|
||||||
"impl": ConversationalPipeline,
|
"impl": ConversationalPipeline,
|
||||||
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
|
"tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
|
||||||
"pt": AutoModelForCausalLM if is_torch_available() else None,
|
"pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
|
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
|
||||||
},
|
},
|
||||||
"image-classification": {
|
"image-classification": {
|
||||||
"impl": ImageClassificationPipeline,
|
"impl": ImageClassificationPipeline,
|
||||||
"tf": None,
|
"tf": (),
|
||||||
"pt": AutoModelForImageClassification if is_torch_available() else None,
|
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -379,53 +379,35 @@ def pipeline(
|
|||||||
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
||||||
>>> pipeline('ner', model=model, tokenizer=tokenizer)
|
>>> pipeline('ner', model=model, tokenizer=tokenizer)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Retrieve the task
|
# Retrieve the task
|
||||||
targeted_task, task_options = check_task(task)
|
targeted_task, task_options = check_task(task)
|
||||||
|
task_class = targeted_task["impl"]
|
||||||
|
|
||||||
# Use default model/config/tokenizer for the task if no model is provided
|
# Use default model/config/tokenizer for the task if no model is provided
|
||||||
if model is None:
|
if model is None:
|
||||||
# At that point framework might still be undetermined
|
# At that point framework might still be undetermined
|
||||||
model = get_default_model(targeted_task, framework, task_options)
|
model = get_default_model(targeted_task, framework, task_options)
|
||||||
|
|
||||||
|
# Config is the primordial information item.
|
||||||
|
# Instantiate config if needed
|
||||||
|
if isinstance(config, str):
|
||||||
|
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
|
||||||
|
elif config is None and isinstance(model, str):
|
||||||
|
config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)
|
||||||
|
|
||||||
model_name = model if isinstance(model, str) else None
|
model_name = model if isinstance(model, str) else None
|
||||||
|
|
||||||
# Infer the framework form the model
|
|
||||||
if framework is None:
|
|
||||||
framework, model = infer_framework_from_model(model, targeted_task, revision=revision, task=task)
|
|
||||||
|
|
||||||
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
|
||||||
|
|
||||||
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
|
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
|
||||||
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
|
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
|
||||||
|
|
||||||
# Instantiate config if needed
|
# Infer the framework from the model
|
||||||
if isinstance(config, str):
|
# Forced if framework already defined, inferred if it's None
|
||||||
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
|
# Will load the correct model if possible
|
||||||
|
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
|
||||||
# Instantiate model if needed
|
framework, model = infer_framework_load_model(
|
||||||
if isinstance(model, str):
|
model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task
|
||||||
# Handle transparent TF/PT model conversion
|
)
|
||||||
if framework == "pt" and model.endswith(".h5"):
|
|
||||||
model_kwargs["from_tf"] = True
|
|
||||||
logger.warning(
|
|
||||||
"Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
|
|
||||||
"Trying to load the model with PyTorch."
|
|
||||||
)
|
|
||||||
elif framework == "tf" and model.endswith(".bin"):
|
|
||||||
model_kwargs["from_pt"] = True
|
|
||||||
logger.warning(
|
|
||||||
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
|
|
||||||
"Trying to load the model with Tensorflow."
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_class is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Pipeline using {framework} framework, but this framework is not supported by this pipeline."
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
|
||||||
model, config=config, revision=revision, _from_pipeline=task, **model_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import csv
|
import csv
|
||||||
|
import importlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@@ -21,11 +22,12 @@ import warnings
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from os.path import abspath, exists
|
from os.path import abspath, exists
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||||
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||||
from ..modelcard import ModelCard
|
from ..modelcard import ModelCard
|
||||||
|
from ..models.auto.configuration_auto import AutoConfig
|
||||||
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
|
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
|
|
||||||
@@ -48,8 +50,108 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_framework_load_model(
|
||||||
|
model,
|
||||||
|
config: AutoConfig,
|
||||||
|
model_classes: Optional[Dict[str, Tuple[type]]] = None,
|
||||||
|
task: Optional[str] = None,
|
||||||
|
framework: Optional[str] = None,
|
||||||
|
**model_kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
|
||||||
|
|
||||||
|
If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
|
||||||
|
:obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
|
||||||
|
Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.
|
||||||
|
|
||||||
|
If both frameworks are installed and available for :obj:`model`, PyTorch is selected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
|
||||||
|
The model to infer the framework from. If :obj:`str`, a checkpoint name. The model to infer the framewrok
|
||||||
|
from.
|
||||||
|
config (:class:`~transformers.AutoConfig`):
|
||||||
|
The config associated with the model to help using the correct class
|
||||||
|
model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
|
||||||
|
A mapping framework to class.
|
||||||
|
task (:obj:`str`):
|
||||||
|
The task defining which pipeline will be returned.
|
||||||
|
model_kwargs:
|
||||||
|
Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
|
||||||
|
**model_kwargs)` function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Tuple`: A tuple framework, model.
|
||||||
|
"""
|
||||||
|
if not is_tf_available() and not is_torch_available():
|
||||||
|
raise RuntimeError(
|
||||||
|
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||||
|
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||||
|
"To install PyTorch, read the instructions at https://pytorch.org/."
|
||||||
|
)
|
||||||
|
if isinstance(model, str):
|
||||||
|
model_kwargs["_from_pipeline"] = task
|
||||||
|
class_tuple = ()
|
||||||
|
look_pt = is_torch_available() and framework in {"pt", None}
|
||||||
|
look_tf = is_tf_available() and framework in {"tf", None}
|
||||||
|
if model_classes:
|
||||||
|
if look_pt:
|
||||||
|
class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
|
||||||
|
if look_tf:
|
||||||
|
class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))
|
||||||
|
if config.architectures:
|
||||||
|
classes = []
|
||||||
|
for architecture in config.architectures:
|
||||||
|
transformers_module = importlib.import_module("transformers")
|
||||||
|
if look_tf:
|
||||||
|
_class = getattr(transformers_module, architecture, None)
|
||||||
|
if _class is not None:
|
||||||
|
classes.append(_class)
|
||||||
|
if look_pt:
|
||||||
|
_class = getattr(transformers_module, f"TF{architecture}", None)
|
||||||
|
if _class is not None:
|
||||||
|
classes.append(_class)
|
||||||
|
class_tuple = class_tuple + tuple(classes)
|
||||||
|
|
||||||
|
if len(class_tuple) == 0:
|
||||||
|
raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")
|
||||||
|
|
||||||
|
for model_class in class_tuple:
|
||||||
|
kwargs = model_kwargs.copy()
|
||||||
|
if framework == "pt" and model.endswith(".h5"):
|
||||||
|
kwargs["from_tf"] = True
|
||||||
|
logger.warning(
|
||||||
|
"Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
|
||||||
|
"Trying to load the model with PyTorch."
|
||||||
|
)
|
||||||
|
elif framework == "tf" and model.endswith(".bin"):
|
||||||
|
kwargs["from_pt"] = True
|
||||||
|
logger.warning(
|
||||||
|
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
|
||||||
|
"Trying to load the model with Tensorflow."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = model_class.from_pretrained(model, **kwargs)
|
||||||
|
# Stop loading on the first successful load.
|
||||||
|
break
|
||||||
|
except (OSError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(model, str):
|
||||||
|
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
|
||||||
|
|
||||||
|
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||||
|
return framework, model
|
||||||
|
|
||||||
|
|
||||||
def infer_framework_from_model(
|
def infer_framework_from_model(
|
||||||
model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs
|
model,
|
||||||
|
model_classes: Optional[Dict[str, Tuple[type]]] = None,
|
||||||
|
task: Optional[str] = None,
|
||||||
|
framework: Optional[str] = None,
|
||||||
|
**model_kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
|
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
|
||||||
@@ -75,30 +177,13 @@ def infer_framework_from_model(
|
|||||||
Returns:
|
Returns:
|
||||||
:obj:`Tuple`: A tuple framework, model.
|
:obj:`Tuple`: A tuple framework, model.
|
||||||
"""
|
"""
|
||||||
if not is_tf_available() and not is_torch_available():
|
|
||||||
raise RuntimeError(
|
|
||||||
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
|
||||||
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
|
||||||
"To install PyTorch, read the instructions at https://pytorch.org/."
|
|
||||||
)
|
|
||||||
if isinstance(model, str):
|
if isinstance(model, str):
|
||||||
model_kwargs["_from_pipeline"] = task
|
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs)
|
||||||
if is_torch_available() and not is_tf_available():
|
else:
|
||||||
model_class = model_classes.get("pt", AutoModel)
|
config = model.config
|
||||||
model = model_class.from_pretrained(model, **model_kwargs)
|
return infer_framework_load_model(
|
||||||
elif is_tf_available() and not is_torch_available():
|
model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs
|
||||||
model_class = model_classes.get("tf", TFAutoModel)
|
)
|
||||||
model = model_class.from_pretrained(model, **model_kwargs)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
model_class = model_classes.get("pt", AutoModel)
|
|
||||||
model = model_class.from_pretrained(model, **model_kwargs)
|
|
||||||
except OSError:
|
|
||||||
model_class = model_classes.get("tf", TFAutoModel)
|
|
||||||
model = model_class.from_pretrained(model, **model_kwargs)
|
|
||||||
|
|
||||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
|
||||||
return framework, model
|
|
||||||
|
|
||||||
|
|
||||||
def get_framework(model, revision: Optional[str] = None):
|
def get_framework(model, revision: Optional[str] = None):
|
||||||
@@ -534,7 +619,7 @@ class Pipeline(_ScikitCompat):
|
|||||||
):
|
):
|
||||||
|
|
||||||
if framework is None:
|
if framework is None:
|
||||||
framework, model = infer_framework_from_model(model)
|
framework, model = infer_framework_load_model(model, config=model.config)
|
||||||
|
|
||||||
self.task = task
|
self.task = task
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ from transformers import (
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
BlenderbotSmallForConditionalGeneration,
|
||||||
|
BlenderbotSmallTokenizer,
|
||||||
Conversation,
|
Conversation,
|
||||||
ConversationalPipeline,
|
ConversationalPipeline,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -389,3 +391,32 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
|||||||
self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
|
self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
|
||||||
self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
|
self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
|
||||||
self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")
|
self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_from_pipeline_conversation(self):
|
||||||
|
model_id = "facebook/blenderbot_small-90M"
|
||||||
|
|
||||||
|
# from model id
|
||||||
|
conversation_agent_from_model_id = pipeline("conversational", model=model_id, tokenizer=model_id)
|
||||||
|
|
||||||
|
# from model object
|
||||||
|
model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_id)
|
||||||
|
tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_id)
|
||||||
|
conversation_agent_from_model = pipeline("conversational", model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
conversation = Conversation("My name is Sarah and I live in London")
|
||||||
|
conversation_copy = Conversation("My name is Sarah and I live in London")
|
||||||
|
|
||||||
|
result_model_id = conversation_agent_from_model_id([conversation])
|
||||||
|
result_model = conversation_agent_from_model([conversation_copy])
|
||||||
|
|
||||||
|
# check for equality
|
||||||
|
self.assertEqual(
|
||||||
|
result_model_id.generated_responses[0],
|
||||||
|
"hi sarah, i live in london as well. do you have any plans for the weekend?",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
result_model_id.generated_responses[0],
|
||||||
|
result_model.generated_responses[0],
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user