Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable. (#7728)
* Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable.
* Actually changing the default resolution order to account for model defaults.
* Adding new tests for each pipeline to check that pipeline(task) also works without manually specifying the framework.
This commit is contained in:
@@ -85,31 +85,63 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_framework(model=None):
|
def get_framework(model):
|
||||||
"""
|
"""
|
||||||
Select framework (TensorFlow or PyTorch) to use.
|
Select framework (TensorFlow or PyTorch) to use.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`, `optional`):
|
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
|
||||||
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
|
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
|
||||||
the model name). If no specific model is provided, defaults to using PyTorch.
|
the model name). If no specific model is provided, defaults to using PyTorch.
|
||||||
"""
|
"""
|
||||||
if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
|
if not is_tf_available() and not is_torch_available():
|
||||||
# Both frameworks are available but the user supplied a model class instance.
|
|
||||||
# Try to guess which framework to use from the model classname
|
|
||||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
|
||||||
elif not is_tf_available() and not is_torch_available():
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||||
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||||
"To install PyTorch, read the instructions at https://pytorch.org/."
|
"To install PyTorch, read the instructions at https://pytorch.org/."
|
||||||
)
|
)
|
||||||
|
if isinstance(model, str):
|
||||||
|
if is_torch_available() and not is_tf_available():
|
||||||
|
model = AutoModel.from_pretrained(model)
|
||||||
|
elif is_tf_available() and not is_torch_available():
|
||||||
|
model = TFAutoModel.from_pretrained(model)
|
||||||
else:
|
else:
|
||||||
# framework = 'tf' if is_tf_available() else 'pt'
|
try:
|
||||||
framework = "pt" if is_torch_available() else "tf"
|
model = AutoModel.from_pretrained(model)
|
||||||
|
except OSError:
|
||||||
|
model = TFAutoModel.from_pretrained(model)
|
||||||
|
|
||||||
|
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||||
return framework
|
return framework
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
|
||||||
|
"""
|
||||||
|
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
targeted_task (:obj:`Dict` ):
|
||||||
|
Dictionary representing the given task, which should contain default models
|
||||||
|
|
||||||
|
framework (:obj:`str`, None)
|
||||||
|
"pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
|
||||||
|
:obj:`str` The model string representing the default model for this pipeline
|
||||||
|
"""
|
||||||
|
if is_torch_available() and not is_tf_available():
|
||||||
|
framework = "pt"
|
||||||
|
elif is_tf_available() and not is_torch_available():
|
||||||
|
framework = "tf"
|
||||||
|
|
||||||
|
default_models = targeted_task["default"]["model"]
|
||||||
|
if framework is None:
|
||||||
|
framework = "pt"
|
||||||
|
|
||||||
|
return default_models[framework]
|
||||||
|
|
||||||
|
|
||||||
class PipelineException(Exception):
|
class PipelineException(Exception):
|
||||||
"""
|
"""
|
||||||
Raised by a :class:`~transformers.Pipeline` when handling __call__.
|
Raised by a :class:`~transformers.Pipeline` when handling __call__.
|
||||||
@@ -2685,14 +2717,16 @@ def pipeline(
|
|||||||
if task not in SUPPORTED_TASKS:
|
if task not in SUPPORTED_TASKS:
|
||||||
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
|
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
|
||||||
|
|
||||||
framework = framework or get_framework(model)
|
|
||||||
|
|
||||||
targeted_task = SUPPORTED_TASKS[task]
|
targeted_task = SUPPORTED_TASKS[task]
|
||||||
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
|
||||||
|
|
||||||
# Use default model/config/tokenizer for the task if no model is provided
|
# Use default model/config/tokenizer for the task if no model is provided
|
||||||
if model is None:
|
if model is None:
|
||||||
model = targeted_task["default"]["model"][framework]
|
# At that point framework might still be undetermined
|
||||||
|
model = get_default_model(targeted_task, framework)
|
||||||
|
|
||||||
|
framework = framework or get_framework(model)
|
||||||
|
|
||||||
|
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
||||||
|
|
||||||
# Try to infer tokenizer from model or config name (if provided as str)
|
# Try to infer tokenizer from model or config name (if provided as str)
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
|||||||
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
||||||
|
|
||||||
NER_FINETUNED_MODELS = ["sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"]
|
NER_FINETUNED_MODELS = ["sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"]
|
||||||
|
TF_NER_FINETUNED_MODELS = ["Narsil/small"]
|
||||||
|
|
||||||
# xlnet-base-cased disabled for now, since it crashes TF2
|
# xlnet-base-cased disabled for now, since it crashes TF2
|
||||||
FEATURE_EXTRACT_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased"]
|
FEATURE_EXTRACT_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased"]
|
||||||
@@ -804,6 +805,14 @@ class NerPipelineTests(unittest.TestCase):
|
|||||||
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
|
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
|
||||||
self._test_ner_pipeline(nlp, mandatory_keys)
|
self._test_ner_pipeline(nlp, mandatory_keys)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_tf_only_ner(self):
|
||||||
|
mandatory_keys = {"entity", "word", "score"}
|
||||||
|
for model_name in TF_NER_FINETUNED_MODELS:
|
||||||
|
# We don't specify framework='tf' but it gets detected automatically
|
||||||
|
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
|
||||||
|
self._test_ner_pipeline(nlp, mandatory_keys)
|
||||||
|
|
||||||
|
|
||||||
class PipelineCommonTests(unittest.TestCase):
|
class PipelineCommonTests(unittest.TestCase):
|
||||||
pipelines = SUPPORTED_TASKS.keys()
|
pipelines = SUPPORTED_TASKS.keys()
|
||||||
@@ -815,6 +824,7 @@ class PipelineCommonTests(unittest.TestCase):
|
|||||||
for task in self.pipelines:
|
for task in self.pipelines:
|
||||||
with self.subTest(msg="Testing TF defaults with TF and {}".format(task)):
|
with self.subTest(msg="Testing TF defaults with TF and {}".format(task)):
|
||||||
pipeline(task, framework="tf")
|
pipeline(task, framework="tf")
|
||||||
|
pipeline(task)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
@@ -823,3 +833,4 @@ class PipelineCommonTests(unittest.TestCase):
|
|||||||
for task in self.pipelines:
|
for task in self.pipelines:
|
||||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
|
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
|
||||||
pipeline(task, framework="pt")
|
pipeline(task, framework="pt")
|
||||||
|
pipeline(task)
|
||||||
|
|||||||
Reference in New Issue
Block a user