Blip dynamic input resolution (#30722)

* blip with interpolated pos encoding

* feat: Add interpolate_pos_encoding option to other models from `BLIP` family.

* include check for textual generated content in tests
This commit is contained in:
Zafir Stojanovski
2024-05-13 13:20:16 +02:00
committed by GitHub
parent a4e530e3c8
commit f63d822242
6 changed files with 240 additions and 20 deletions

View File

@@ -1381,6 +1381,20 @@ class BlipModelIntegrationTest(unittest.TestCase):
[30522, 1037, 3861, 1997, 1037, 2450, 1998, 2014, 3899, 2006, 1996, 3509, 102],
)
def test_inference_interpolate_pos_encoding(self):
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(torch_device)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
processor.image_processor.size = {"height": 500, "width": 500}
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(torch_device)
predictions = model.generate(**inputs, interpolate_pos_encoding=True)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
self.assertEqual(predictions[0].tolist(), [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 1037, 3899, 102])
self.assertEqual(generated_text, "a woman sitting on the beach with a dog")
def test_inference_vqa(self):
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(torch_device)
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")