[gemma3] fix bidirectional image mask (#39396)
* fix gemma3 mask
* make compile happy, and use only torch ops
* no full attention between images
* update tests
* fix tests
* add a fast test
This commit is contained in:
committed by
GitHub
parent
fbeaf96f9e
commit
3bc726b381
@@ -646,8 +646,8 @@ class GenerationMixin(ContinuousMixin):
|
|||||||
|
|
||||||
# If it's not defined, it means the model uses the new general mask API
|
# If it's not defined, it means the model uses the new general mask API
|
||||||
if causal_mask_creation_function is None: # can't be found
|
if causal_mask_creation_function is None: # can't be found
|
||||||
token_type_ids = getattr(model_input, "token_type_ids", None)
|
token_type_ids = model_inputs.get("token_type_ids", None)
|
||||||
position_ids = getattr(model_input, position_ids_key, None)
|
position_ids = model_inputs.get(position_ids_key, None)
|
||||||
# Some models may overwrite the general one
|
# Some models may overwrite the general one
|
||||||
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
|
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
|
||||||
attention_mask = causal_mask_creation_function(
|
attention_mask = causal_mask_creation_function(
|
||||||
|
|||||||
@@ -737,7 +737,11 @@ class Gemma3MultiModalProjector(nn.Module):
|
|||||||
return projected_vision_outputs.type_as(vision_outputs)
|
return projected_vision_outputs.type_as(vision_outputs)
|
||||||
|
|
||||||
|
|
||||||
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
|
def token_type_ids_mask_function(
|
||||||
|
token_type_ids: Optional[torch.Tensor],
|
||||||
|
image_group_ids: Optional[torch.Tensor],
|
||||||
|
tokens_per_image: int,
|
||||||
|
) -> Optional[Callable]:
|
||||||
"""
|
"""
|
||||||
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
||||||
not start and end indices.
|
not start and end indices.
|
||||||
@@ -747,10 +751,18 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||||
# If the difference is less than image size, both are part of the same image block
|
|
||||||
same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
|
|
||||||
# If it's 1 for both query and key/value, we are in an image block
|
# If it's 1 for both query and key/value, we are in an image block
|
||||||
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
|
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
|
||||||
|
# Since vmap doesn't support `if` statements, we work around it with `torch.where`
|
||||||
|
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
|
||||||
|
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
|
||||||
|
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
|
||||||
|
|
||||||
|
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
|
||||||
|
image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
|
||||||
|
|
||||||
|
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
|
||||||
|
same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
|
||||||
|
|
||||||
# This is bidirectional attention whenever we are dealing with image tokens
|
# This is bidirectional attention whenever we are dealing with image tokens
|
||||||
return is_image_block & same_image_block
|
return is_image_block & same_image_block
|
||||||
@@ -915,8 +927,15 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
|||||||
}
|
}
|
||||||
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
|
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
|
||||||
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
||||||
|
|
||||||
|
# First find where a new image block starts: 1 if image and previous not image
|
||||||
|
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
||||||
|
is_image = (token_type_ids == 1).to(cache_position.device)
|
||||||
|
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||||
|
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||||
|
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||||
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
||||||
token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
|
token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the masks
|
# Create the masks
|
||||||
@@ -1181,8 +1200,15 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
|||||||
# Add the token type ids mask for generate as well
|
# Add the token type ids mask for generate as well
|
||||||
if token_type_ids is not None and input_embeds.shape[1] != 1:
|
if token_type_ids is not None and input_embeds.shape[1] != 1:
|
||||||
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
||||||
|
|
||||||
|
# First find where a new image block starts: 1 if image and previous not image
|
||||||
|
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
||||||
|
is_image = (token_type_ids == 1).to(cache_position.device)
|
||||||
|
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||||
|
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||||
|
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||||
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
||||||
token_type_ids.to(cache_position.device), config.mm_tokens_per_image
|
token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image
|
||||||
)
|
)
|
||||||
|
|
||||||
return create_masks_for_generate(**mask_kwargs)
|
return create_masks_for_generate(**mask_kwargs)
|
||||||
|
|||||||
@@ -716,7 +716,11 @@ class Gemma3MultiModalProjector(nn.Module):
|
|||||||
return projected_vision_outputs.type_as(vision_outputs)
|
return projected_vision_outputs.type_as(vision_outputs)
|
||||||
|
|
||||||
|
|
||||||
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
|
def token_type_ids_mask_function(
|
||||||
|
token_type_ids: Optional[torch.Tensor],
|
||||||
|
image_group_ids: Optional[torch.Tensor],
|
||||||
|
tokens_per_image: int,
|
||||||
|
) -> Optional[Callable]:
|
||||||
"""
|
"""
|
||||||
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
||||||
not start and end indices.
|
not start and end indices.
|
||||||
@@ -726,10 +730,18 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||||
# If the difference is less than image size, both are part of the same image block
|
|
||||||
same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
|
|
||||||
# If it's 1 for both query and key/value, we are in an image block
|
# If it's 1 for both query and key/value, we are in an image block
|
||||||
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
|
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
|
||||||
|
# Since vmap doesn't support `if` statements, we work around it with `torch.where`
|
||||||
|
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
|
||||||
|
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
|
||||||
|
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
|
||||||
|
|
||||||
|
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
|
||||||
|
image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
|
||||||
|
|
||||||
|
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
|
||||||
|
same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
|
||||||
|
|
||||||
# This is bidirectional attention whenever we are dealing with image tokens
|
# This is bidirectional attention whenever we are dealing with image tokens
|
||||||
return is_image_block & same_image_block
|
return is_image_block & same_image_block
|
||||||
@@ -840,8 +852,15 @@ class Gemma3Model(PaliGemmaModel):
|
|||||||
}
|
}
|
||||||
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
|
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
|
||||||
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
||||||
|
|
||||||
|
# First find where a new image block starts: 1 if image and previous not image
|
||||||
|
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
||||||
|
is_image = (token_type_ids == 1).to(cache_position.device)
|
||||||
|
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||||
|
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||||
|
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||||
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
||||||
token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
|
token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the masks
|
# Create the masks
|
||||||
@@ -1062,8 +1081,15 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
|
|||||||
# Add the token type ids mask for generate as well
|
# Add the token type ids mask for generate as well
|
||||||
if token_type_ids is not None and input_embeds.shape[1] != 1:
|
if token_type_ids is not None and input_embeds.shape[1] != 1:
|
||||||
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
||||||
|
|
||||||
|
# First find where a new image block starts: 1 if image and previous not image
|
||||||
|
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
||||||
|
is_image = (token_type_ids == 1).to(cache_position.device)
|
||||||
|
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||||
|
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||||
|
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||||
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
||||||
token_type_ids.to(cache_position.device), config.mm_tokens_per_image
|
token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image
|
||||||
)
|
)
|
||||||
|
|
||||||
return create_masks_for_generate(**mask_kwargs)
|
return create_masks_for_generate(**mask_kwargs)
|
||||||
|
|||||||
@@ -270,6 +270,45 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
|||||||
self.model_tester = Gemma3Vision2TextModelTester(self)
|
self.model_tester = Gemma3Vision2TextModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
|
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
|
||||||
|
|
||||||
|
def test_bidirectional_image_attention(self):
|
||||||
|
"""
|
||||||
|
Tests that each image can attend to itself bidirectionally. However an image
|
||||||
|
cannot attend to future images, even within the same sequence.
|
||||||
|
"""
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config._attn_implementation = "eager"
|
||||||
|
model = Gemma3Model(config).to(torch_device)
|
||||||
|
|
||||||
|
# First let's pass inputs without change which is one image per text and manipulate
|
||||||
|
# `token_type_ids` to make sure bidirectional mask is applied where it has to be
|
||||||
|
inputs_dict["token_type_ids"] = torch.zeros_like(inputs_dict["token_type_ids"])
|
||||||
|
inputs_dict["token_type_ids"][:, :4] = 1 # unmask first 4 tokens
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model(**inputs_dict, output_attentions=True)
|
||||||
|
# We expect a non-causal mask on first 4 tokens, thus no zeros
|
||||||
|
for attention in out.attentions:
|
||||||
|
self.assertTrue((attention[..., :4, :4] != 0).all().item())
|
||||||
|
|
||||||
|
# Now when removing `token_type_ids`, we will get simple causal mask
|
||||||
|
inputs_dict["token_type_ids"][:, :4] = 0 # mask back first 4 tokens
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model(**inputs_dict, output_attentions=True)
|
||||||
|
# We expect a causal mask on first 4 tokens, thus there must be zeros (above the diagonal)
|
||||||
|
for attention in out.attentions:
|
||||||
|
self.assertFalse((attention[..., :4, :4] != 0).all().item())
|
||||||
|
|
||||||
|
# Let's add two "images" per text, first one spanning 4 tokens and last one 3 tokens
|
||||||
|
inputs_dict["token_type_ids"][:, :4] = 1
|
||||||
|
inputs_dict["token_type_ids"][:, 7:10] = 1
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model(**inputs_dict, output_attentions=True)
|
||||||
|
for attention in out.attentions:
|
||||||
|
self.assertTrue((attention[..., :4, :4] != 0).all().item())
|
||||||
|
self.assertTrue((attention[..., 7:10, 7:10] != 0).all().item())
|
||||||
|
|
||||||
|
# We expect a non-causal mask only within same image and no looking ahead to the future
|
||||||
|
self.assertTrue((attention[..., :4, 7:10] == 0).all().item())
|
||||||
|
|
||||||
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
||||||
def test_training_gradient_checkpointing(self):
|
def test_training_gradient_checkpointing(self):
|
||||||
pass
|
pass
|
||||||
@@ -413,7 +452,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
EXPECTED_TEXTS = Expectations(
|
EXPECTED_TEXTS = Expectations(
|
||||||
{
|
{
|
||||||
("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water in the background. It looks like a lovely,'],
|
("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water in the background. It looks like a lovely,'],
|
||||||
("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'],
|
("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like'],
|
||||||
("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'],
|
("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'],
|
||||||
}
|
}
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
@@ -463,8 +502,8 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
("cuda", 8):
|
("cuda", 8):
|
||||||
[
|
[
|
||||||
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks',
|
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like',
|
||||||
'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes. \n\n* **Image 1** shows a cow standing on a beach'
|
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a brown"
|
||||||
],
|
],
|
||||||
("rocm", (9, 5)):
|
("rocm", (9, 5)):
|
||||||
[
|
[
|
||||||
@@ -508,7 +547,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
("xpu", 3): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
|
("xpu", 3): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
|
||||||
("cuda", 7): [],
|
("cuda", 7): [],
|
||||||
("cuda", 8): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
|
("cuda", 8): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the"],
|
||||||
}
|
}
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
|
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
|
||||||
@@ -565,8 +604,8 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
("cuda", 7): [],
|
("cuda", 7): [],
|
||||||
("cuda", 8): [
|
("cuda", 8): [
|
||||||
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
|
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the",
|
||||||
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a',
|
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a'
|
||||||
],
|
],
|
||||||
("rocm", (9, 5)) : [
|
("rocm", (9, 5)) : [
|
||||||
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
|
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
|
||||||
@@ -610,7 +649,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
("xpu", 3): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image!\n\nHere's a description of the scene:\n\n* **Chinese Arch"],
|
("xpu", 3): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image!\n\nHere's a description of the scene:\n\n* **Chinese Arch"],
|
||||||
("cuda", 7): [],
|
("cuda", 7): [],
|
||||||
("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"],
|
("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"],
|
||||||
}
|
}
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
|
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
|
||||||
|
|||||||
Reference in New Issue
Block a user