Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable. (#7728)
* Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable. * Actually changing the default resolution order to account for model defaults Adding a new tests for each pipeline to check that pipeline(task) works too without manually adding the framework too.
This commit is contained in:
@@ -85,31 +85,63 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_framework(model=None):
|
||||
def get_framework(model):
|
||||
"""
|
||||
Select framework (TensorFlow or PyTorch) to use.
|
||||
|
||||
Args:
|
||||
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`, `optional`):
|
||||
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
|
||||
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
|
||||
the model name). If no specific model is provided, defaults to using PyTorch.
|
||||
"""
|
||||
if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
|
||||
# Both framework are available but the user supplied a model class instance.
|
||||
# Try to guess which framework to use from the model classname
|
||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||
elif not is_tf_available() and not is_torch_available():
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
raise RuntimeError(
|
||||
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||
"To install PyTorch, read the instructions at https://pytorch.org/."
|
||||
)
|
||||
else:
|
||||
# framework = 'tf' if is_tf_available() else 'pt'
|
||||
framework = "pt" if is_torch_available() else "tf"
|
||||
if isinstance(model, str):
|
||||
if is_torch_available() and not is_tf_available():
|
||||
model = AutoModel.from_pretrained(model)
|
||||
elif is_tf_available() and not is_torch_available():
|
||||
model = TFAutoModel.from_pretrained(model)
|
||||
else:
|
||||
try:
|
||||
model = AutoModel.from_pretrained(model)
|
||||
except OSError:
|
||||
model = TFAutoModel.from_pretrained(model)
|
||||
|
||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||
return framework
|
||||
|
||||
|
||||
def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
|
||||
"""
|
||||
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
|
||||
|
||||
Args:
|
||||
targeted_task (:obj:`Dict` ):
|
||||
Dictionnary representing the given task, that should contain default models
|
||||
|
||||
framework (:obj:`str`, None)
|
||||
"pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.
|
||||
|
||||
Returns
|
||||
|
||||
:obj:`str` The model string representing the default model for this pipeline
|
||||
"""
|
||||
if is_torch_available() and not is_tf_available():
|
||||
framework = "pt"
|
||||
elif is_tf_available() and not is_torch_available():
|
||||
framework = "tf"
|
||||
|
||||
default_models = targeted_task["default"]["model"]
|
||||
if framework is None:
|
||||
framework = "pt"
|
||||
|
||||
return default_models[framework]
|
||||
|
||||
|
||||
class PipelineException(Exception):
|
||||
"""
|
||||
Raised by a :class:`~transformers.Pipeline` when handling __call__.
|
||||
@@ -2685,14 +2717,16 @@ def pipeline(
|
||||
if task not in SUPPORTED_TASKS:
|
||||
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
|
||||
|
||||
framework = framework or get_framework(model)
|
||||
|
||||
targeted_task = SUPPORTED_TASKS[task]
|
||||
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
||||
|
||||
# Use default model/config/tokenizer for the task if no model is provided
|
||||
if model is None:
|
||||
model = targeted_task["default"]["model"][framework]
|
||||
# At that point framework might still be undetermined
|
||||
model = get_default_model(targeted_task, framework)
|
||||
|
||||
framework = framework or get_framework(model)
|
||||
|
||||
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
||||
|
||||
# Try to infer tokenizer from model or config name (if provided as str)
|
||||
if tokenizer is None:
|
||||
|
||||
@@ -10,6 +10,7 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
||||
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
||||
|
||||
NER_FINETUNED_MODELS = ["sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"]
|
||||
TF_NER_FINETUNED_MODELS = ["Narsil/small"]
|
||||
|
||||
# xlnet-base-cased disabled for now, since it crashes TF2
|
||||
FEATURE_EXTRACT_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased"]
|
||||
@@ -804,6 +805,14 @@ class NerPipelineTests(unittest.TestCase):
|
||||
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
|
||||
self._test_ner_pipeline(nlp, mandatory_keys)
|
||||
|
||||
@require_tf
|
||||
def test_tf_only_ner(self):
|
||||
mandatory_keys = {"entity", "word", "score"}
|
||||
for model_name in TF_NER_FINETUNED_MODELS:
|
||||
# We don't specificy framework='tf' but it gets detected automatically
|
||||
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
|
||||
self._test_ner_pipeline(nlp, mandatory_keys)
|
||||
|
||||
|
||||
class PipelineCommonTests(unittest.TestCase):
|
||||
pipelines = SUPPORTED_TASKS.keys()
|
||||
@@ -815,6 +824,7 @@ class PipelineCommonTests(unittest.TestCase):
|
||||
for task in self.pipelines:
|
||||
with self.subTest(msg="Testing TF defaults with TF and {}".format(task)):
|
||||
pipeline(task, framework="tf")
|
||||
pipeline(task)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@@ -823,3 +833,4 @@ class PipelineCommonTests(unittest.TestCase):
|
||||
for task in self.pipelines:
|
||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
|
||||
pipeline(task, framework="pt")
|
||||
pipeline(task)
|
||||
|
||||
Reference in New Issue
Block a user