committed by
GitHub
parent
4ab33c2d81
commit
fad15fba78
@@ -104,14 +104,14 @@ class LlavaProcessor(ProcessorMixin):
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
if images is not None:
|
||||
pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
|
||||
image_inputs = self.image_processor(images, return_tensors=return_tensors)
|
||||
else:
|
||||
pixel_values = None
|
||||
image_inputs = {}
|
||||
text_inputs = self.tokenizer(
|
||||
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
|
||||
)
|
||||
|
||||
return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
|
||||
return BatchFeature(data={**text_inputs, **image_inputs})
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
|
||||
@@ -458,3 +458,16 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_OUTPUT = ['<|im_start|>', 'system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', 'user', '\n', '<image>', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', 'ass', 'istant', '\n'] # fmt: skip
|
||||
self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
|
||||
self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_generation_no_images(self):
|
||||
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||
model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
# Prepare inputs with no images
|
||||
inputs = processor("Hello, I am", return_tensors="pt").to(torch_device)
|
||||
|
||||
# Make sure that `generate` works
|
||||
_ = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
Reference in New Issue
Block a user