Fix chameleon tests (#38565)

* update

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-06-04 10:13:35 +02:00
committed by GitHub
parent 55736eea99
commit 3c995c1fdc

View File

@@ -21,6 +21,7 @@ from parameterized import parameterized
from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed
from transformers.testing_utils import ( from transformers.testing_utils import (
Expectations,
require_bitsandbytes, require_bitsandbytes,
require_read_token, require_read_token,
require_torch, require_torch,
@@ -417,7 +418,14 @@ class ChameleonIntegrationTest(unittest.TestCase):
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16) inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)
# greedy generation outputs # greedy generation outputs
EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in'] # fmt: skip EXPECTED_TEXT_COMPLETIONS = Expectations(
{
("cuda", 7): ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in'],
("cuda", 8): ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot representing the position of the star Alpha Centauri. Alpha Centauri is the brightest star in the constellation Centaurus and is located'],
}
) # fmt: skip
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
text = processor.batch_decode(generated_ids, skip_special_tokens=True) text = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text) self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@@ -447,10 +455,20 @@ class ChameleonIntegrationTest(unittest.TestCase):
) )
# greedy generation outputs # greedy generation outputs
EXPECTED_TEXT_COMPLETION = [ EXPECTED_TEXT_COMPLETIONS = Expectations(
{
("cuda", 7): [
'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and',
'What constellation is this image showing?The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.',
],
("cuda", 8): [
'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in', 'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in',
'What constellation is this image showing?The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.' 'What constellation is this image showing?The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.',
] # fmt: skip ],
}
) # fmt: skip
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
text = processor.batch_decode(generated_ids, skip_special_tokens=True) text = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text) self.assertEqual(EXPECTED_TEXT_COMPLETION, text)