@@ -280,7 +280,7 @@ class IdeficsProcessor(ProcessorMixin):
|
|||||||
else:
|
else:
|
||||||
return fake_token + image_token + fake_token
|
return fake_token + image_token + fake_token
|
||||||
|
|
||||||
all_texts = []
|
all_prompts = []
|
||||||
all_images = []
|
all_images = []
|
||||||
for sample in prompts:
|
for sample in prompts:
|
||||||
# the model was trained on samples starting with <s>
|
# the model was trained on samples starting with <s>
|
||||||
@@ -321,16 +321,17 @@ class IdeficsProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
image_objects = self.image_processor(image_objects, transform=transform)
|
image_objects = self.image_processor(image_objects, transform=transform)
|
||||||
|
|
||||||
|
all_prompts.append(full_text)
|
||||||
|
all_images.append(image_objects)
|
||||||
|
|
||||||
text_encoding = self.tokenizer(
|
text_encoding = self.tokenizer(
|
||||||
text=full_text,
|
text=all_prompts,
|
||||||
add_special_tokens=False,
|
add_special_tokens=False,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
truncation=truncation,
|
truncation=truncation,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
|
all_texts = text_encoding["input_ids"]
|
||||||
all_texts.append(text_encoding["input_ids"])
|
|
||||||
all_images.append(image_objects)
|
|
||||||
|
|
||||||
max_seq_len = max(len(x) for x in all_texts)
|
max_seq_len = max(len(x) for x in all_texts)
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,25 @@ class IdeficsProcessorTest(TestCasePlus):
|
|||||||
|
|
||||||
self.assertListEqual(decoded_tok, decoded_processor)
|
self.assertListEqual(decoded_tok, decoded_processor)
|
||||||
|
|
||||||
|
def test_tokenizer_padding(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
tokenizer = self.get_tokenizer(padding_side="right")
|
||||||
|
|
||||||
|
processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
|
||||||
|
predicted_tokens = [
|
||||||
|
"<s>Describe this image.\nAssistant:<unk><unk><unk><unk><unk><unk><unk><unk><unk>",
|
||||||
|
"<s>Describe this image.\nAssistant:<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>",
|
||||||
|
]
|
||||||
|
|
||||||
|
prompts = [[prompt] for prompt in self.prepare_prompts()[2]]
|
||||||
|
max_length = processor(prompts, padding="max_length", truncation=True, max_length=20)
|
||||||
|
longest = processor(prompts, padding="longest", truncation=True, max_length=30)
|
||||||
|
decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1])
|
||||||
|
decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1])
|
||||||
|
self.assertEqual(decoded_max_length, predicted_tokens[1])
|
||||||
|
self.assertEqual(decoded_longest, predicted_tokens[0])
|
||||||
|
|
||||||
def test_model_input_names(self):
|
def test_model_input_names(self):
|
||||||
image_processor = self.get_image_processor()
|
image_processor = self.get_image_processor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|||||||
Reference in New Issue
Block a user