Enable auto task for timm models in pipeline (#35531)
* Enable auto task for timm models * Add pipeline test
This commit is contained in:
committed by
GitHub
parent
1a6c1d3a9a
commit
657bb14f98
@@ -494,7 +494,7 @@ def get_task(model: str, token: Optional[str] = None, **deprecated_kwargs) -> st
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
|
f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
|
||||||
)
|
)
|
||||||
if getattr(info, "library_name", "transformers") != "transformers":
|
if getattr(info, "library_name", "transformers") not in {"transformers", "timm"}:
|
||||||
raise RuntimeError(f"This model is meant to be used with {info.library_name} not with transformers")
|
raise RuntimeError(f"This model is meant to be used with {info.library_name} not with transformers")
|
||||||
task = info.pipeline_tag
|
task = info.pipeline_tag
|
||||||
return task
|
return task
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import inspect
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from transformers import pipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_timm,
|
require_timm,
|
||||||
@@ -294,6 +295,19 @@ class TimmWrapperModelIntegrationTest(unittest.TestCase):
|
|||||||
is_close = torch.allclose(resulted_slice, expected_slice, atol=1e-3)
|
is_close = torch.allclose(resulted_slice, expected_slice, atol=1e-3)
|
||||||
self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")
|
self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_with_pipeline(self):
|
||||||
|
image = prepare_img()
|
||||||
|
classifier = pipeline(model="timm/resnet18.a1_in1k", device=torch_device)
|
||||||
|
result = classifier(image)
|
||||||
|
|
||||||
|
# verify result
|
||||||
|
expected_label = "tabby, tabby cat"
|
||||||
|
expected_score = 0.4329
|
||||||
|
|
||||||
|
self.assertEqual(result[0]["label"], expected_label)
|
||||||
|
self.assertAlmostEqual(result[0]["score"], expected_score, places=3)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
def test_inference_image_classification_quantized(self):
|
def test_inference_image_classification_quantized(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user