BLIP: fix generation after hub update (#34876)
* fix blip generation * dont remove it yet * Update src/transformers/models/blip_2/modeling_blip_2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * address comments * modular --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
c1a8520419
commit
098962dac2
@@ -1994,8 +1994,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Test output
|
||||
print(predictions[0].tolist(), generated_text)
|
||||
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
|
||||
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118] # fmt: skip
|
||||
self.assertEqual(predictions[0].tolist(), expected_ids)
|
||||
self.assertEqual("a woman sitting on the beach with a dog", generated_text)
|
||||
|
||||
# image and context
|
||||
@@ -2007,10 +2007,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Test output
|
||||
self.assertEqual(
|
||||
predictions[0].tolist(),
|
||||
[2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118],
|
||||
)
|
||||
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118] # fmt: skip
|
||||
self.assertEqual(predictions[0].tolist(), expected_ids)
|
||||
self.assertEqual(generated_text, "Question: which city is this? Answer: it's not a city, it's a beach")
|
||||
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
@@ -2026,7 +2024,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
predictions = model.generate(**inputs, interpolate_pos_encoding=True)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118])
|
||||
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 8, 2335, 15, 5, 4105, 50118] # fmt: skip
|
||||
self.assertEqual(predictions[0].tolist(), expected_ids)
|
||||
self.assertEqual(generated_text, "a woman and dog on the beach")
|
||||
|
||||
def test_inference_opt_batched_beam_search(self):
|
||||
@@ -2042,8 +2041,9 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
predictions = model.generate(**inputs, num_beams=2)
|
||||
|
||||
# Test output (in this case, slightly different from greedy search)
|
||||
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
|
||||
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
|
||||
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118] # fmt: skip
|
||||
self.assertEqual(predictions[0].tolist(), expected_ids)
|
||||
self.assertEqual(predictions[1].tolist(), expected_ids)
|
||||
|
||||
def test_inference_t5(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
||||
@@ -2070,10 +2070,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Test output
|
||||
self.assertEqual(
|
||||
predictions[0].tolist(),
|
||||
[0, 3, 7, 152, 67, 839, 1],
|
||||
)
|
||||
self.assertEqual(predictions[0].tolist(), [0, 3, 7, 152, 67, 839, 1])
|
||||
self.assertEqual(generated_text, "san diego")
|
||||
|
||||
def test_inference_t5_batched_beam_search(self):
|
||||
|
||||
Reference in New Issue
Block a user