[internvl] fix chat template (#37656)
* fix chat template * update * update conversion * rename `fake_image_token` in tests
This commit is contained in:
committed by
GitHub
parent
9ec8be56dd
commit
1e9087368c
@@ -296,7 +296,9 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompt = (
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
@@ -314,7 +316,9 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompt = (
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
# Forward
|
||||
@@ -378,8 +382,8 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||
@@ -414,8 +418,8 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image><image>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(
|
||||
@@ -485,6 +489,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
num_frames=8,
|
||||
).to(torch_device, dtype=torch.float16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
@@ -552,6 +557,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
num_frames=8,
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
@@ -601,7 +607,9 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompt = (
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
@@ -619,7 +627,9 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompt = (
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
# Forward
|
||||
@@ -687,8 +697,8 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||
@@ -724,8 +734,8 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image><image>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(
|
||||
@@ -795,6 +805,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
num_frames=8,
|
||||
).to(torch_device, dtype=torch.float16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
@@ -862,6 +873,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
num_frames=8,
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
@@ -64,7 +64,8 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
**processor_kwargs,
|
||||
)
|
||||
processor.save_pretrained(cls.tmpdirname)
|
||||
cls.image_token = processor.fake_image_token
|
||||
cls.image_token = processor.image_token
|
||||
cls.video_token = processor.video_token
|
||||
|
||||
@staticmethod
|
||||
def prepare_processor_dict():
|
||||
@@ -138,6 +139,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
num_frames=8,
|
||||
)
|
||||
|
||||
# Process non batched inputs to check if the pixel_values and input_ids are reconstructed in the correct order when batched together
|
||||
@@ -150,6 +152,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
num_frames=8,
|
||||
)
|
||||
# We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded
|
||||
torch.testing.assert_close(
|
||||
@@ -223,6 +226,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="np",
|
||||
num_frames=8,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
|
||||
@@ -272,30 +276,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames)
|
||||
|
||||
# Load with `video_fps` arg
|
||||
video_fps = 1
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=None, # force to use default num_frames
|
||||
return_tensors="np",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), video_fps * 10)
|
||||
|
||||
# Load with `video_fps` and `num_frames` args, should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
# Load with `video_fps` arg is not possible with InternVL (skip)
|
||||
|
||||
# Load without any arg should use the default loading method
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
@@ -305,8 +286,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 300)
|
||||
|
||||
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
|
||||
# because we assume they come from one video
|
||||
|
||||
Reference in New Issue
Block a user