Generate: input expansion for any model input (#21624)
This commit is contained in:
@@ -797,7 +797,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(generated_text, "it's not a city, it's a beach")
|
||||
|
||||
def test_inference_opt_batched(self):
|
||||
def test_inference_opt_batched_beam_search(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
|
||||
|
||||
@@ -805,11 +805,11 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
image = prepare_img()
|
||||
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
predictions = model.generate(**inputs, num_beams=2)
|
||||
|
||||
# Test output
|
||||
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
|
||||
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
|
||||
# Test output (in this case, slightly different from greedy search)
|
||||
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
|
||||
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
|
||||
|
||||
def test_inference_t5(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
||||
@@ -842,7 +842,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(generated_text, "san diego")
|
||||
|
||||
def test_inference_t5_batched(self):
|
||||
def test_inference_t5_batched_beam_search(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
||||
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
|
||||
|
||||
@@ -850,8 +850,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
image = prepare_img()
|
||||
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
predictions = model.generate(**inputs, num_beams=2)
|
||||
|
||||
# Test output
|
||||
self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
|
||||
self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
|
||||
# Test output (in this case, slightly different from greedy search)
|
||||
self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
|
||||
self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
|
||||
|
||||
Reference in New Issue
Block a user