VLMs: major clean up 🧼 (#34502)

only lllava models are modified
This commit is contained in:
Raushan Turganbay
2025-01-08 10:35:23 +01:00
committed by GitHub
parent 7176e06b52
commit d1681ec2b6
19 changed files with 197 additions and 1028 deletions

View File

@@ -327,10 +327,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")
EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
@@ -378,7 +375,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20)
@@ -402,7 +399,9 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True).to(
torch_device
)
output = model.generate(**inputs, max_new_tokens=20)
@@ -434,7 +433,9 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True)
inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True).to(
torch_device
)
output = model.generate(**inputs, max_new_tokens=20)
@@ -508,32 +509,18 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
# Simulate some user inputs
pixel_values = torch.randn(
(1, 3, 336, 336),
dtype=torch.float,
device=torch_device,
)
input_ids = torch.tensor(
[
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
],
dtype=torch.long,
device=torch_device,
)
attention_mask = torch.tensor(
[[0, 0, 1, 1, 1, 1, 1, 1, 1]],
dtype=torch.long,
device=torch_device,
)
prompt = "USER: <image>\nDescribe the imageASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
# Make sure that the loss is properly computed
loss = model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
labels=input_ids,
**inputs,
labels=inputs.input_ids.clone(),
).loss
loss.backward()
@@ -593,38 +580,6 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT = "user\n\nWhat are these?\nassistant The image shows two cats, one on the left and one on the right. They appear to be resting or sleeping on a pink blanket. The cat"
self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
def test_expansion_in_processing(self):
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
prompt = "USER: <image>\nDescribe the image:\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.num_additional_image_tokens = 1
processor.patch_size = 14
inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 18)
# generate exactly 20 tokens
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
@slow
@require_bitsandbytes
def test_pixtral(self):