Generate: input expansion for any model input (#21624)

This commit is contained in:
Joao Gante
2023-02-14 14:16:22 +00:00
committed by GitHub
parent 13e03e619d
commit a81fe4e1df
3 changed files with 58 additions and 67 deletions

View File

@@ -797,7 +797,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
)
self.assertEqual(generated_text, "it's not a city, it's a beach")
def test_inference_opt_batched(self):
def test_inference_opt_batched_beam_search(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
@@ -805,11 +805,11 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs)
predictions = model.generate(**inputs, num_beams=2)
# Test output
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
# Test output (in this case, slightly different from greedy search)
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
def test_inference_t5(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
@@ -842,7 +842,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
)
self.assertEqual(generated_text, "san diego")
def test_inference_t5_batched(self):
def test_inference_t5_batched_beam_search(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
@@ -850,8 +850,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs)
predictions = model.generate(**inputs, num_beams=2)
# Test output
self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
# Test output (in this case, slightly different from greedy search)
self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])