docs: fix return type annotation of get_default_model_revision (#35982)
This commit is contained in:
committed by
GitHub
parent
6a1ab634b6
commit
3c912c9089
@@ -384,7 +384,7 @@ def get_framework(model, revision: Optional[str] = None):
|
|||||||
|
|
||||||
def get_default_model_and_revision(
|
def get_default_model_and_revision(
|
||||||
targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]
|
targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]
|
||||||
) -> Union[str, Tuple[str, str]]:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
|
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
|
||||||
|
|
||||||
@@ -401,7 +401,9 @@ def get_default_model_and_revision(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
|
||||||
`str` The model string representing the default model for this pipeline
|
Tuple:
|
||||||
|
- `str` The model string representing the default model for this pipeline.
|
||||||
|
- `str` The revision of the model.
|
||||||
"""
|
"""
|
||||||
if is_torch_available() and not is_tf_available():
|
if is_torch_available() and not is_tf_available():
|
||||||
framework = "pt"
|
framework = "pt"
|
||||||
|
|||||||
@@ -796,7 +796,7 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
pipeline_class=PairClassificationPipeline,
|
pipeline_class=PairClassificationPipeline,
|
||||||
pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
|
pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
|
||||||
tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
|
tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
|
||||||
default={"pt": "hf-internal-testing/tiny-random-distilbert"},
|
default={"pt": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")},
|
||||||
type="text",
|
type="text",
|
||||||
)
|
)
|
||||||
assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
|
assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
|
||||||
@@ -806,7 +806,9 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
|
self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
|
||||||
self.assertEqual(task_def["type"], "text")
|
self.assertEqual(task_def["type"], "text")
|
||||||
self.assertEqual(task_def["impl"], PairClassificationPipeline)
|
self.assertEqual(task_def["impl"], PairClassificationPipeline)
|
||||||
self.assertEqual(task_def["default"], {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}})
|
self.assertEqual(
|
||||||
|
task_def["default"], {"model": {"pt": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")}}
|
||||||
|
)
|
||||||
|
|
||||||
# Clean registry for next tests.
|
# Clean registry for next tests.
|
||||||
del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]
|
del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]
|
||||||
|
|||||||
Reference in New Issue
Block a user