Paligemma support for multi-image (#33447)
* update * Update src/transformers/models/paligemma/processing_paligemma.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * update docs * better example in tests * support image tokens * read token * Update tests/models/paligemma/test_processing_paligemma.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * nit: naming * Update docs/source/en/model_doc/paligemma.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * conflicts after rebasing --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
55b7a0404e
commit
3e039d3827
@@ -326,8 +326,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_small_model_integration_test(self):
|
||||
# Let's make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@@ -349,8 +347,40 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_multiimage(self):
    """Generate with a multi-image checkpoint using two images for one prompt.

    Two prompts are checked: one with the same snowman image twice, and one
    with a snowman image plus a stop-sign image. In both cases the decoded
    output must equal the prompt followed by the expected answer.
    """
    model_id = "google/paligemma-3b-ft-nlvr2-448"  # checkpoint tuned for multiple images
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
    processor = PaliGemmaProcessor.from_pretrained(model_id)
    prompt = "answer en There is no snowman in any of the images. Is this true or false?"
    stop_sign_image = Image.open(
        requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
    )
    snow_image = Image.open(
        requests.get(
            "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg", stream=True
        ).raw
    )

    # The images are passed as a nested list: one inner list holding both
    # images for the single prompt.
    inputs = processor(text=prompt, images=[[snow_image, snow_image]], return_tensors="pt")

    output = model.generate(**inputs, max_new_tokens=20)
    EXPECTED_DECODED_TEXT = "answer en There is no snowman in any of the images. Is this true or false?\nFalse"

    # Decode with the processor loaded for this checkpoint, i.e. the same one
    # that encoded the inputs, rather than an unrelated shared instance.
    self.assertEqual(
        processor.decode(output[0], skip_special_tokens=True),
        EXPECTED_DECODED_TEXT,
    )

    # Try another prompt, this time with two different images.
    prompt = "answer en There is exactly one snowman. Is this true or false?"
    inputs = processor(text=prompt, images=[[snow_image, stop_sign_image]], return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=20)
    EXPECTED_DECODED_TEXT = "answer en There is exactly one snowman. Is this true or false?\nTrue"
    self.assertEqual(
        processor.decode(output[0], skip_special_tokens=True),
        EXPECTED_DECODED_TEXT,
    )
|
||||
|
||||
def test_small_model_integration_test_paligemma_VQA(self):
|
||||
# Let's make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@@ -370,8 +400,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_paligemma_empty_prompt(self):
|
||||
# Let's make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@@ -392,8 +420,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_paligemma_batched(self):
|
||||
# Let's make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@@ -420,9 +446,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_paligemma_batched_bf16(self):
|
||||
# Let's make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@@ -452,9 +475,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_paligemma_batched_f16(self):
|
||||
# Let's make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@@ -485,9 +505,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_integration_detection_bug(self):
|
||||
# this is a reproducer of https://github.com/huggingface/transformers/issues/31425 where not enough context
|
||||
# impacted negatively segmentation generations.
|
||||
@@ -511,8 +528,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe" # fmt: skip
|
||||
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_paligemma_index_error_bug(self):
|
||||
# This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore
|
||||
# Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for
|
||||
@@ -536,9 +551,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
# Make sure that `generate` works
|
||||
_ = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_paligemma_finetuning_with_suffixes_bf16(self):
|
||||
# this is a supplementary test to ensure paligemma fine-tuning that relies on token_type_ids is robust to future changes
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
|
||||
84
tests/models/paligemma/test_processing_paligemma.py
Normal file
84
tests/models/paligemma/test_processing_paligemma.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoProcessor, GemmaTokenizerFast, PaliGemmaProcessor
|
||||
from transformers.testing_utils import require_read_token, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import SiglipImageProcessor
|
||||
|
||||
|
||||
@require_vision
@require_read_token
class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    """Processor tests for `PaliGemmaProcessor`, including explicit `<image>` tokens in the text."""

    processor_class = PaliGemmaProcessor

    def setUp(self):
        """Build a processor from a Siglip image processor and a Gemma tokenizer and save it to a temp dir."""
        self.tmpdirname = tempfile.mkdtemp()
        image_processor = SiglipImageProcessor(do_center_crop=False)
        tokenizer = GemmaTokenizerFast.from_pretrained("google/gemma-7b")
        image_processor.image_seq_length = 32

        processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
        processor.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        """Return the tokenizer of the processor reloaded from the temp dir."""
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        """Return the image processor of the processor reloaded from the temp dir."""
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def tearDown(self):
        """Remove the temp dir created in `setUp`."""
        shutil.rmtree(self.tmpdirname)

    def test_text_with_image_tokens(self):
        """Text with explicit `<image>` tokens must be processed like the equivalent text without them."""
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        text_multi_images = "<image><image><bos>Dummy text!"
        text_single_image = "<image><bos>Dummy text!"
        text_no_image = "Dummy text!"

        image = self.prepare_image_inputs()[0]

        # One image: explicit "<image><bos>" prefix vs. plain text.
        out_noimage = processor(text=text_no_image, images=image, return_tensors="np")
        out_single_image = processor(text=text_single_image, images=image, return_tensors="np")
        for k in out_noimage:
            self.assertEqual(out_noimage[k].tolist(), out_single_image[k].tolist())

        # Two images for one text: explicit tokens with a flat list vs. plain text with a nested list.
        out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
        out_noimage_two_images = processor(text=text_no_image, images=[[image, image]], return_tensors="np")
        for k in out_noimage_two_images:
            self.assertEqual(out_noimage_two_images[k].tolist(), out_multiimages[k].tolist())

        # We can't be sure what the user's intention is: whether they want "one text + two images"
        # or forgot to add the second text, so a flat list of two images with one plain text must raise.
        with self.assertRaises(ValueError):
            processor(text=text_no_image, images=[image, image], return_tensors="np")

        # Batched: one image per text, given as explicit tokens, as a nested list, or as a flat list.
        text_batched = ["Dummy text!", "Dummy text!"]
        text_batched_with_image = ["<image><bos>Dummy text!", "<image><bos>Dummy text!"]
        out_images = processor(text=text_batched_with_image, images=[image, image], return_tensors="np")
        out_noimage_nested = processor(text=text_batched, images=[[image], [image]], return_tensors="np")
        out_noimage = processor(text=text_batched, images=[image, image], return_tensors="np")
        for k in out_noimage:
            self.assertEqual(out_noimage[k].tolist(), out_images[k].tolist())
            self.assertEqual(out_noimage[k].tolist(), out_noimage_nested[k].tolist())
|
||||
Reference in New Issue
Block a user