Gemma3 processor typo (#36710)
* fix typo when is on * tiny * add test and remove 'text_crops' * lint
This commit is contained in:
@@ -384,7 +384,7 @@ class Gemma3ImageProcessor(BaseImageProcessor):
|
|||||||
images_list = [images for images, _ in images_list_and_num_crops]
|
images_list = [images for images, _ in images_list_and_num_crops]
|
||||||
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
|
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
|
||||||
else:
|
else:
|
||||||
num_crops = [[0] for images in images_list]
|
num_crops = [[0] for _ in images_list]
|
||||||
|
|
||||||
processed_images = []
|
processed_images = []
|
||||||
for images in images_list:
|
for images in images_list:
|
||||||
|
|||||||
@@ -113,7 +113,6 @@ class Gemma3Processor(ProcessorMixin):
|
|||||||
|
|
||||||
# Replace image tokens by the full expanded sequence
|
# Replace image tokens by the full expanded sequence
|
||||||
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
||||||
text_with_crops = text
|
|
||||||
for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
|
for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
|
||||||
image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
|
image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
|
||||||
|
|
||||||
@@ -130,7 +129,7 @@ class Gemma3Processor(ProcessorMixin):
|
|||||||
+ " ".join([self.boi_token] * num)
|
+ " ".join([self.boi_token] * num)
|
||||||
)
|
)
|
||||||
prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
|
prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
|
||||||
text_with_crops[batch_idx] = prompt
|
text[batch_idx] = prompt
|
||||||
|
|
||||||
# Expand placeholder image tokens to the full image token sequence
|
# Expand placeholder image tokens to the full image token sequence
|
||||||
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
|
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
|
||||||
|
|||||||
@@ -417,6 +417,39 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
|
def test_model_4b_crops(self):
|
||||||
|
model_id = "gg-hf-g/gemma-3-4b-it"
|
||||||
|
|
||||||
|
model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||||
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
crop_config = {
|
||||||
|
"images_kwargs": {
|
||||||
|
"do_pan_and_scan": True,
|
||||||
|
"pan_and_scan_max_num_crops": 448,
|
||||||
|
"pan_and_scan_min_crop_size": 32,
|
||||||
|
"pan_and_scan_min_ratio_to_activate": 0.3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs = self.processor.apply_chat_template(
|
||||||
|
self.messages,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
add_generation_prompt=True,
|
||||||
|
**crop_config,
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||||
|
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||||
|
|
||||||
|
EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images
|
||||||
|
EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nDescribe this image in detail.\nmodel\nHere's a detailed description of the image:\n\n**Overall Impression:**\n\nThe image is a close-up shot of a garden scene featuring several"] # fmt: skip
|
||||||
|
self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
|
||||||
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
def test_model_4b_multiimage(self):
|
def test_model_4b_multiimage(self):
|
||||||
model_id = "gg-hf-g/gemma-3-4b-it"
|
model_id = "gg-hf-g/gemma-3-4b-it"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user