BLIPs clean-up (#35560)
* blips clean up * update processor * readability * fix processor length * fix copies * tmp * update and fix copies * why keep these, delete? * fix test fetcher * irrelevant comment * fix tests * fix tests * fix copies
This commit is contained in:
committed by
GitHub
parent
4f8f51be4e
commit
75794792ad
@@ -809,34 +809,3 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
|
||||
predictions[0].tolist(), [0, 37, 1023, 753, 3, 9, 2335, 3823, 30, 8, 2608, 28, 3, 9, 1782, 5, 1]
|
||||
)
|
||||
self.assertEqual(generated_text, "The image features a woman sitting on the beach with a dog.")
|
||||
|
||||
def test_expansion_in_processing(self):
|
||||
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
|
||||
model = InstructBlipForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/instructblip-flan-t5-xl",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(torch_device)
|
||||
|
||||
image = prepare_img()
|
||||
prompt = "What's in the image?"
|
||||
|
||||
# Make sure we will go the legacy path by setting these args to None
|
||||
processor.num_query_tokens = None
|
||||
model.config.image_token_index = None
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Add args to the config to trigger new logic when inputs are expanded in processing file
|
||||
processor.num_query_tokens = model.config.num_query_tokens
|
||||
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
||||
model.config.image_token_index = len(processor.tokenizer) - 2
|
||||
model.resize_token_embeddings(processor.tokenizer.vocab_size, pad_to_multiple_of=64)
|
||||
|
||||
# Generate again with new inputs
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
predictions_expanded = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
||||
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
||||
|
||||
self.assertTrue(generated_text_expanded == generated_text)
|
||||
|
||||
Reference in New Issue
Block a user