[InstructBlip] Add instruct blip int8 test (#24555)
* add 8bit instructblip test * update tests
This commit is contained in:
@@ -521,51 +521,39 @@ def prepare_img():
|
||||
@require_torch
|
||||
@slow
|
||||
class InstructBlipModelIntegrationTest(unittest.TestCase):
|
||||
# TODO (@Younes): Re-enable this when 8-bit or 4-bit is implemented.
|
||||
@unittest.skip(reason="GPU OOM")
|
||||
def test_inference_vicuna_7b(self):
|
||||
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
||||
model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b").to(
|
||||
torch_device
|
||||
model = InstructBlipForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/instructblip-vicuna-7b", load_in_8bit=True
|
||||
)
|
||||
|
||||
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
prompt = "What is unusual about this image?"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
|
||||
|
||||
# verify logits
|
||||
with torch.no_grad():
|
||||
logits = model(**inputs).logits
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-3.4684, -12.6759, 8.5067], [-5.1305, -12.2058, 7.9834], [-4.0632, -13.9285, 9.2327]],
|
||||
[[-3.5410, -12.2812, 8.2812], [-5.2500, -12.0938, 7.8398], [-4.1523, -13.8281, 9.0000]],
|
||||
device=torch_device,
|
||||
)
|
||||
assert torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-5)
|
||||
self.assertTrue(torch.allclose(logits[0, :3, :3].float(), expected_slice, atol=1e-3))
|
||||
|
||||
# verify generation
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
num_beams=5,
|
||||
max_length=256,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.5,
|
||||
length_penalty=1.0,
|
||||
temperature=1,
|
||||
)
|
||||
outputs = model.generate(**inputs, max_new_tokens=30)
|
||||
outputs[outputs == 0] = 2
|
||||
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# fmt: off
|
||||
expected_outputs = [2, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 29892, 607, 338, 14089, 287, 297, 278, 7256, 310, 263, 19587, 4272, 11952, 29889, 910, 338, 385, 443, 535, 794, 1848, 2948, 304, 13977, 292, 22095, 29892, 408, 372, 6858, 278, 767, 304, 17346, 3654, 322, 670, 13977, 292, 21083, 373, 2246, 310, 278, 19716, 1550, 12402, 1218, 1549, 12469, 29889, 19814, 29892, 278, 10122, 310, 8818, 275, 322, 916, 24413, 297, 278, 9088, 4340, 19310, 7093, 278, 22910, 5469, 310, 445, 6434, 29889, 2, 1]
|
||||
expected_outputs = [ 2, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 1550, 19500, 1623, 263, 19587, 4272, 11952, 29889]
|
||||
# fmt: on
|
||||
self.assertEqual(outputs[0].tolist(), expected_outputs)
|
||||
self.assertEqual(
|
||||
generated_text,
|
||||
"The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation.",
|
||||
"The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV while driving down a busy city street.",
|
||||
)
|
||||
|
||||
def test_inference_flant5_xl(self):
|
||||
|
||||
Reference in New Issue
Block a user