Fix number of patch check for different vision feature select strategy (#32494)
* Fix number of patch check for different vision feature select strategy * add test --------- Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
@@ -620,3 +620,24 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_full_vision_state_selection(self):
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
load_in_4bit=True,
|
||||
)
|
||||
# test that changing `strategy` won't error out
|
||||
model.vision_feature_select_strategy = "full"
|
||||
|
||||
inputs = self.processor(self.prompt, self.image, return_tensors="pt")
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=30)
|
||||
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes' # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user