From 657bb14f981c8d7e3ad77fe309bee0951cbf7186 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 8 Jan 2025 15:14:17 +0000 Subject: [PATCH] Enable auto task for timm models in pipeline (#35531) * Enable auto task for timm models * Add pipeline test --- src/transformers/pipelines/__init__.py | 2 +- .../timm_wrapper/test_modeling_timm_wrapper.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 07156b3cf1..257f5689b0 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -494,7 +494,7 @@ def get_task(model: str, token: Optional[str] = None, **deprecated_kwargs) -> st raise RuntimeError( f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically" ) - if getattr(info, "library_name", "transformers") != "transformers": + if getattr(info, "library_name", "transformers") not in {"transformers", "timm"}: raise RuntimeError(f"This model is meant to be used with {info.library_name} not with transformers") task = info.pipeline_tag return task diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py index cf35d90518..360bbc9371 100644 --- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py +++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py @@ -17,6 +17,7 @@ import inspect import tempfile import unittest +from transformers import pipeline from transformers.testing_utils import ( require_bitsandbytes, require_timm, @@ -294,6 +295,19 @@ class TimmWrapperModelIntegrationTest(unittest.TestCase): is_close = torch.allclose(resulted_slice, expected_slice, atol=1e-3) self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}") + @slow + def test_inference_with_pipeline(self): + image = prepare_img() + classifier = pipeline(model="timm/resnet18.a1_in1k", device=torch_device) + result = classifier(image) + + # verify result + expected_label = "tabby, tabby cat" + expected_score = 0.4329 + + self.assertEqual(result[0]["label"], expected_label) + self.assertAlmostEqual(result[0]["score"], expected_score, places=3) + @slow @require_bitsandbytes def test_inference_image_classification_quantized(self):