fix output data type of image classification (#31444)
* fix output data type of image classification * add tests for low-precision pipeline * add bf16 pipeline tests * fix bf16 tests * Update tests/pipelines/test_pipelines_image_classification.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix import * fix import torch * fix style --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,8 @@ if is_tf_available():
|
|||||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -180,6 +182,9 @@ class ImageClassificationPipeline(Pipeline):
|
|||||||
top_k = self.model.config.num_labels
|
top_k = self.model.config.num_labels
|
||||||
|
|
||||||
outputs = model_outputs["logits"][0]
|
outputs = model_outputs["logits"][0]
|
||||||
|
if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16):
|
||||||
|
outputs = outputs.to(torch.float32).numpy()
|
||||||
|
else:
|
||||||
outputs = outputs.numpy()
|
outputs = outputs.numpy()
|
||||||
|
|
||||||
if function_to_apply == ClassificationFunction.SIGMOID:
|
if function_to_apply == ClassificationFunction.SIGMOID:
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from transformers import (
|
|||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.pipelines import ImageClassificationPipeline, pipeline
|
from transformers.pipelines import ImageClassificationPipeline, pipeline
|
||||||
@@ -34,6 +35,9 @@ from transformers.testing_utils import (
|
|||||||
from .test_pipelines_common import ANY
|
from .test_pipelines_common import ANY
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
else:
|
else:
|
||||||
@@ -177,6 +181,30 @@ class ImageClassificationPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertIs(image_classifier.tokenizer, tokenizer)
|
self.assertIs(image_classifier.tokenizer, tokenizer)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_torch_float16_pipeline(self):
|
||||||
|
image_classifier = pipeline(
|
||||||
|
"image-classification", model="hf-internal-testing/tiny-random-vit", torch_dtype=torch.float16
|
||||||
|
)
|
||||||
|
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=3),
|
||||||
|
[{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_torch_bfloat16_pipeline(self):
|
||||||
|
image_classifier = pipeline(
|
||||||
|
"image-classification", model="hf-internal-testing/tiny-random-vit", torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=3),
|
||||||
|
[{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_perceiver(self):
|
def test_perceiver(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user