BLIP: fix generation after hub update (#34876)

* fix blip generation

* dont remove it yet

* Update src/transformers/models/blip_2/modeling_blip_2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* address comments

* modular

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-11-25 10:41:55 +01:00
committed by GitHub
parent c1a8520419
commit 098962dac2
7 changed files with 42 additions and 35 deletions

View File

@@ -1994,8 +1994,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
print(predictions[0].tolist(), generated_text)
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118] # fmt: skip
self.assertEqual(predictions[0].tolist(), expected_ids)
self.assertEqual("a woman sitting on the beach with a dog", generated_text)
# image and context
@@ -2007,10 +2007,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
self.assertEqual(
predictions[0].tolist(),
[2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118],
)
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118] # fmt: skip
self.assertEqual(predictions[0].tolist(), expected_ids)
self.assertEqual(generated_text, "Question: which city is this? Answer: it's not a city, it's a beach")
def test_inference_interpolate_pos_encoding(self):
@@ -2026,7 +2024,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
predictions = model.generate(**inputs, interpolate_pos_encoding=True)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118])
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 8, 2335, 15, 5, 4105, 50118] # fmt: skip
self.assertEqual(predictions[0].tolist(), expected_ids)
self.assertEqual(generated_text, "a woman and dog on the beach")
def test_inference_opt_batched_beam_search(self):
@@ -2042,8 +2041,9 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
predictions = model.generate(**inputs, num_beams=2)
# Test output (in this case, slightly different from greedy search)
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118] # fmt: skip
self.assertEqual(predictions[0].tolist(), expected_ids)
self.assertEqual(predictions[1].tolist(), expected_ids)
def test_inference_t5(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
@@ -2070,10 +2070,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
self.assertEqual(
predictions[0].tolist(),
[0, 3, 7, 152, 67, 839, 1],
)
self.assertEqual(predictions[0].tolist(), [0, 3, 7, 152, 67, 839, 1])
self.assertEqual(generated_text, "san diego")
def test_inference_t5_batched_beam_search(self):

View File

@@ -945,7 +945,7 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
# Add args to the config to trigger new logic when inputs are expanded in processing file
processor.num_query_tokens = model.config.num_query_tokens
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
model.config.image_token_index = len(processor.tokenizer) - 1
model.config.image_token_index = len(processor.tokenizer) - 2
model.resize_token_embeddings(processor.tokenizer.vocab_size, pad_to_multiple_of=64)
# Generate again with new inputs