[smolvlm] fix video inference (#39147)
* fix smolvlm * better do as before, set sampling params in overwritten `apply_chat_template` * style * update with `setdefault`
This commit is contained in:
committed by
GitHub
parent
9b2f5b66d8
commit
4d5822e65d
@@ -536,23 +536,24 @@ class SmolVLMForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
).content
|
||||
)
|
||||
)
|
||||
self.image2 = Image.open(
|
||||
BytesIO(requests.get("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg").content)
|
||||
)
|
||||
self.image3 = Image.open(
|
||||
BytesIO(
|
||||
requests.get(
|
||||
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
|
||||
).content
|
||||
)
|
||||
)
|
||||
|
||||
self.video_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"path": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/assisted-generation/gif_1_1080p.mov",
|
||||
},
|
||||
{"type": "text", "text": "Describe this video in detail"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@slow
|
||||
# TODO (Orr?) this is a dummy test to check if the model generates things that make sense.
|
||||
# Needs to be expanded to a tiny video
|
||||
def test_integration_test(self):
|
||||
model = SmolVLMForConditionalGeneration.from_pretrained(
|
||||
"HuggingFaceTB/SmolVLM2-256M-Video-Instruct",
|
||||
@@ -571,3 +572,26 @@ class SmolVLMForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
expected_generated_text = "\n\n\n\nIn this image, we see a view of the Statue of Liberty and the"
|
||||
self.assertEqual(generated_texts[0], expected_generated_text)
|
||||
|
||||
@slow
|
||||
def test_integration_test_video(self):
|
||||
model = SmolVLMForConditionalGeneration.from_pretrained(
|
||||
"HuggingFaceTB/SmolVLM2-256M-Video-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# Create inputs
|
||||
inputs = self.processor.apply_chat_template(
|
||||
self.video_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(device=torch_device, dtype=torch.bfloat16)
|
||||
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=20)
|
||||
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
expected_generated_text = 'User: You are provided the following series of nine frames from a 0:00:09 [H:MM:SS] video.\n\nFrame from 00:00:\nFrame from 00:01:\nFrame from 00:02:\nFrame from 00:03:\nFrame from 00:04:\nFrame from 00:05:\nFrame from 00:06:\nFrame from 00:08:\nFrame from 00:09:\n\nDescribe this video in detail\nAssistant: The video depicts a large language model architecture, specifically a language model with a "quick brown" feature' # fmt: skip
|
||||
self.assertEqual(generated_texts[0], expected_generated_text)
|
||||
|
||||
Reference in New Issue
Block a user