From 2be8a9098e06262bdd5c16b5e8a70f145df88e96 Mon Sep 17 00:00:00 2001 From: raghavanone <115454562+raghavanone@users.noreply.github.com> Date: Thu, 31 Aug 2023 19:38:20 +0530 Subject: [PATCH] Save image_processor while saving pipeline (ImageSegmentationPipeline) (#25884) * Save image_processor while saving pipeline (ImageSegmentationPipeline) * Fix black issues --- src/transformers/pipelines/base.py | 3 +++ .../test_pipelines_image_segmentation.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 153e8e9f6b..f1af0f865b 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -872,6 +872,9 @@ class Pipeline(_ScikitCompat): if self.feature_extractor is not None: self.feature_extractor.save_pretrained(save_directory) + if self.image_processor is not None: + self.image_processor.save_pretrained(save_directory) + if self.modelcard is not None: self.modelcard.save_pretrained(save_directory) diff --git a/tests/pipelines/test_pipelines_image_segmentation.py b/tests/pipelines/test_pipelines_image_segmentation.py index 72150fe739..dbc0c0db80 100644 --- a/tests/pipelines/test_pipelines_image_segmentation.py +++ b/tests/pipelines/test_pipelines_image_segmentation.py @@ -13,6 +13,7 @@ # limitations under the License. import hashlib +import tempfile import unittest from typing import Dict @@ -714,3 +715,17 @@ class ImageSegmentationPipelineTests(unittest.TestCase): }, ], ) + + def test_save_load(self): + model_id = "hf-internal-testing/tiny-detr-mobilenetsv3-panoptic" + + model = AutoModelForImageSegmentation.from_pretrained(model_id) + image_processor = AutoImageProcessor.from_pretrained(model_id) + image_segmenter = pipeline( + task="image-segmentation", + model=model, + image_processor=image_processor, + ) + with tempfile.TemporaryDirectory() as tmpdirname: + image_segmenter.save_pretrained(tmpdirname) + pipeline(task="image-segmentation", model=tmpdirname)