Blip dynamic input resolution (#30722)

* blip with interpolated pos encoding

* feat: Add interpolate_pos_encoding option to other models from `BLIP` family.

* include check for textual generated content in tests
This commit is contained in:
Zafir Stojanovski
2024-05-13 13:20:16 +02:00
committed by GitHub
parent a4e530e3c8
commit f63d822242
6 changed files with 240 additions and 20 deletions

View File

@@ -882,6 +882,22 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
)
self.assertEqual(generated_text, "it's not a city, it's a beach")
def test_inference_interpolate_pos_encoding(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
).to(torch_device)
processor.image_processor.size = {"height": 500, "width": 500}
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs, interpolate_pos_encoding=True)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118])
self.assertEqual(generated_text, "a woman and dog on the beach")
def test_inference_opt_batched_beam_search(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(