Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable. (#7728)
* Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable. * Actually changing the default resolution order to account for model defaults Adding a new tests for each pipeline to check that pipeline(task) works too without manually adding the framework too.
This commit is contained in:
@@ -10,6 +10,7 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
||||
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
||||
|
||||
NER_FINETUNED_MODELS = ["sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"]
|
||||
TF_NER_FINETUNED_MODELS = ["Narsil/small"]
|
||||
|
||||
# xlnet-base-cased disabled for now, since it crashes TF2
|
||||
FEATURE_EXTRACT_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased"]
|
||||
@@ -804,6 +805,14 @@ class NerPipelineTests(unittest.TestCase):
|
||||
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
|
||||
self._test_ner_pipeline(nlp, mandatory_keys)
|
||||
|
||||
@require_tf
|
||||
def test_tf_only_ner(self):
|
||||
mandatory_keys = {"entity", "word", "score"}
|
||||
for model_name in TF_NER_FINETUNED_MODELS:
|
||||
# We don't specificy framework='tf' but it gets detected automatically
|
||||
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
|
||||
self._test_ner_pipeline(nlp, mandatory_keys)
|
||||
|
||||
|
||||
class PipelineCommonTests(unittest.TestCase):
|
||||
pipelines = SUPPORTED_TASKS.keys()
|
||||
@@ -815,6 +824,7 @@ class PipelineCommonTests(unittest.TestCase):
|
||||
for task in self.pipelines:
|
||||
with self.subTest(msg="Testing TF defaults with TF and {}".format(task)):
|
||||
pipeline(task, framework="tf")
|
||||
pipeline(task)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@@ -823,3 +833,4 @@ class PipelineCommonTests(unittest.TestCase):
|
||||
for task in self.pipelines:
|
||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
|
||||
pipeline(task, framework="pt")
|
||||
pipeline(task)
|
||||
|
||||
Reference in New Issue
Block a user