Fix Llava for 0-embeddings (#30473)

This commit is contained in:
Raushan Turganbay
2024-04-25 20:28:51 +05:00
committed by GitHub
parent ad697f1801
commit e60491adc9
4 changed files with 41 additions and 6 deletions

View File

@@ -459,3 +459,29 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
def test_small_model_integration_test_unk_token(self):
    """Run a prompt containing a literal ``<unk>`` token through the model.

    Regression test related to (#29835). The 4-bit checkpoint is loaded, a
    prompt with both ``<image>`` and ``<unk>`` is processed together with
    ``self.image``, and two things are exercised:

    1. a plain forward pass under ``torch.no_grad()`` (only required to
       complete without raising; its result is not inspected);
    2. ``generate`` with 40 new tokens, whose first decoded sequence
       (special tokens skipped) must equal the pinned expected text.
    """
    EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart'  # fmt: skip

    model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf",
        load_in_4bit=True,
    )

    prompt = "[INST] <image>\nWhat is shown in this <unk> image? [/INST]"
    inputs = self.processor(prompt, self.image, return_tensors="pt").to(torch_device)

    # Step 1: a single forward pass must succeed; the output itself is unused.
    with torch.no_grad():
        model(**inputs)

    # Step 2: generation must reproduce the pinned text.
    generated_ids = model.generate(**inputs, max_new_tokens=40)
    decoded_text = self.processor.decode(generated_ids[0], skip_special_tokens=True)
    self.assertEqual(decoded_text, EXPECTED_DECODED_TEXT)