Paligemma support for multi-image (#33447)

* update

* Update src/transformers/models/paligemma/processing_paligemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* update docs

* better example in tests

* support image tokens

* read token

* Update tests/models/paligemma/test_processing_paligemma.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* nit: naming

* Update docs/source/en/model_doc/paligemma.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* conflicts after rebasing

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-09-27 11:23:14 +02:00
committed by GitHub
parent 55b7a0404e
commit 3e039d3827
4 changed files with 222 additions and 49 deletions

View File

@@ -326,8 +326,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
@slow
@require_read_token
def test_small_model_integration_test(self):
# Let's make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
@@ -349,8 +347,40 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT,
)
@slow
@require_read_token
def test_small_model_integration_test_multiimage(self):
    """Check generation when a single prompt is paired with two images.

    Loads the NLVR2 fine-tuned checkpoint together with its own processor, then runs
    two prompts: one with the same snowman image passed twice, and one with a snowman
    image plus a stop-sign image. Images are passed as a nested list (one inner list
    per prompt) and the decoded generation is compared to a fixed expected string.
    """
    model_id = "google/paligemma-3b-ft-nlvr2-448"  # checkpoint tuned for multiple images
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
    processor = PaliGemmaProcessor.from_pretrained(model_id)

    stop_sign_image = Image.open(
        requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
    )
    snow_image = Image.open(
        requests.get(
            "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg", stream=True
        ).raw
    )

    prompt = "answer en There is no snowman in any of the images. Is this true or false?"
    inputs = processor(text=prompt, images=[[snow_image, snow_image]], return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=20)

    EXPECTED_DECODED_TEXT = "answer en There is no snowman in any of the images. Is this true or false?\nFalse"
    # Decode with the processor that built the inputs, so the check does not depend on a
    # separately configured `self.processor` fixture.
    self.assertEqual(
        processor.decode(output[0], skip_special_tokens=True),
        EXPECTED_DECODED_TEXT,
    )

    # Try another prompt, this time with two different images.
    prompt = "answer en There is exactly one snowman. Is this true or false?"
    inputs = processor(text=prompt, images=[[snow_image, stop_sign_image]], return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=20)

    EXPECTED_DECODED_TEXT = "answer en There is exactly one snowman. Is this true or false?\nTrue"
    self.assertEqual(
        processor.decode(output[0], skip_special_tokens=True),
        EXPECTED_DECODED_TEXT,
    )
def test_small_model_integration_test_paligemma_VQA(self):
# Let's make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
@@ -370,8 +400,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT,
)
@slow
@require_read_token
def test_small_model_integration_test_paligemma_empty_prompt(self):
# Let's make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
@@ -392,8 +420,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT,
)
@slow
@require_read_token
def test_small_model_integration_test_paligemma_batched(self):
# Let's make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
@@ -420,9 +446,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_torch
@require_read_token
def test_small_model_integration_test_paligemma_batched_bf16(self):
# Let's make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
@@ -452,9 +475,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_torch
@require_read_token
def test_small_model_integration_test_paligemma_batched_f16(self):
# Let's make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
@@ -485,9 +505,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_torch
@require_read_token
def test_integration_detection_bug(self):
# this is a reproducer of https://github.com/huggingface/transformers/issues/31425 where not enough context
# impacted negatively segmentation generations.
@@ -511,8 +528,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe" # fmt: skip
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_read_token
def test_paligemma_index_error_bug(self):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore
# Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for
@@ -536,9 +551,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)
@slow
@require_torch
@require_read_token
def test_paligemma_finetuning_with_suffixes_bf16(self):
# this is a supplementary test to ensure paligemma fine-tuning that relies on token_type_ids is robust to future changes
model_id = "google/paligemma-3b-pt-224"

View File

@@ -0,0 +1,84 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from transformers import AutoProcessor, GemmaTokenizerFast, PaliGemmaProcessor
from transformers.testing_utils import require_read_token, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import SiglipImageProcessor
@require_vision
@require_read_token
class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    """Processor tests for `PaliGemmaProcessor`.

    `setUp` builds a processor from a `SiglipImageProcessor` and a `GemmaTokenizerFast`
    and saves it to a temporary directory; the component getters reload from there.
    """

    processor_class = PaliGemmaProcessor

    def setUp(self):
        """Create a temporary directory and save a freshly built processor into it."""
        self.tmpdirname = tempfile.mkdtemp()
        image_processor = SiglipImageProcessor(do_center_crop=False)
        tokenizer = GemmaTokenizerFast.from_pretrained("google/gemma-7b")
        image_processor.image_seq_length = 32
        processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
        processor.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        """Return the tokenizer of the processor saved in the temporary directory."""
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        """Return the image processor of the processor saved in the temporary directory."""
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def tearDown(self):
        """Remove the temporary directory created in `setUp`."""
        shutil.rmtree(self.tmpdirname)

    def test_text_with_image_tokens(self):
        """Prompts with explicit `<image>` tokens must match prompts without them.

        Covers a single image, two images for one prompt (nested list), the ambiguous
        "one text + flat list of two images" case that must raise `ValueError`, and a
        batch of two prompts with one image each (flat and nested image lists).
        """
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        text_multi_images = "<image><image><bos>Dummy text!"
        text_single_image = "<image><bos>Dummy text!"
        text_no_image = "Dummy text!"

        image = self.prepare_image_inputs()[0]

        # Single image: explicit image token vs. no image token in the text.
        out_noimage = processor(text=text_no_image, images=image, return_tensors="np")
        out_single_image = processor(text=text_single_image, images=image, return_tensors="np")
        for k in out_noimage:
            self.assertEqual(out_noimage[k].tolist(), out_single_image[k].tolist())

        # Two images for one prompt: explicit tokens vs. a nested image list.
        out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
        out_noimage = processor(text=text_no_image, images=[[image, image]], return_tensors="np")

        # We can't be sure what the user's intention is: whether they want "one text + two images"
        # or forgot to add the second text, so a flat list of two images with one text must fail.
        with self.assertRaises(ValueError):
            processor(text=text_no_image, images=[image, image], return_tensors="np")

        # `out_noimage` here is the result of the nested-list call above.
        for k in out_noimage:
            self.assertEqual(out_noimage[k].tolist(), out_multiimages[k].tolist())

        # Batched prompts, one image each: explicit tokens, nested list, and flat list must agree.
        text_batched = ["Dummy text!", "Dummy text!"]
        text_batched_with_image = ["<image><bos>Dummy text!", "<image><bos>Dummy text!"]
        out_images = processor(text=text_batched_with_image, images=[image, image], return_tensors="np")
        out_noimage_nested = processor(text=text_batched, images=[[image], [image]], return_tensors="np")
        out_noimage = processor(text=text_batched, images=[image, image], return_tensors="np")
        for k in out_noimage:
            self.assertEqual(out_noimage[k].tolist(), out_images[k].tolist())
            self.assertEqual(out_images[k].tolist(), out_noimage_nested[k].tolist())