Fix number of patch check for different vision feature select strategy (#32494)

* Fix number of patch check for different vision feature select strategy

* add test

---------

Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
Insu Jang
2024-09-17 03:33:07 -04:00
committed by GitHub
parent 18e1a9c719
commit bcf8946f0a
2 changed files with 32 additions and 2 deletions

View File

@@ -620,3 +620,24 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
@slow
@require_bitsandbytes
def test_small_model_integration_test_full_vision_state_selection(self):
model = LlavaNextForConditionalGeneration.from_pretrained(
"llava-hf/llava-v1.6-mistral-7b-hf",
load_in_4bit=True,
)
# test that changing `strategy` won't error out
model.vision_feature_select_strategy = "full"
inputs = self.processor(self.prompt, self.image, return_tensors="pt")
# verify generation
output = model.generate(**inputs, max_new_tokens=30)
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes' # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)