Generate: correct default model input creation for decoder-only models (#21580)

This commit is contained in:
Joao Gante
2023-02-13 17:04:49 +00:00
committed by GitHub
parent edc1e734bf
commit fa4bdb0a40
4 changed files with 109 additions and 17 deletions

View File

@@ -797,6 +797,20 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
)
self.assertEqual(generated_text, "it's not a city, it's a beach")
def test_inference_opt_batched(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
# prepare image
image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs)
# Test output
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
def test_inference_t5(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained(
@@ -827,3 +841,17 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
[0, 3, 7, 152, 67, 839, 1],
)
self.assertEqual(generated_text, "san diego")
def test_inference_t5_batched(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
# prepare image
image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs)
# Test output
self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])