[GIT] Add test for batched generation (#21282)
* Add test * Apply suggestions Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -495,3 +495,22 @@ class GitModelIntegrationTest(unittest.TestCase):
|
|||||||
expected_shape = torch.Size((1, 15))
|
expected_shape = torch.Size((1, 15))
|
||||||
self.assertEqual(generated_ids.shape, expected_shape)
|
self.assertEqual(generated_ids.shape, expected_shape)
|
||||||
self.assertEquals(generated_caption, "what does the front of the bus say at the top? special")
|
self.assertEquals(generated_caption, "what does the front of the bus say at the top? special")
|
||||||
|
|
||||||
|
def test_batched_generation(self):
|
||||||
|
processor = GitProcessor.from_pretrained("microsoft/git-base-coco")
|
||||||
|
model = GitForCausalLM.from_pretrained("microsoft/git-base-coco")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
# create batch of size 2
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
inputs = processor(images=[image, image], return_tensors="pt")
|
||||||
|
pixel_values = inputs.pixel_values.to(torch_device)
|
||||||
|
|
||||||
|
# we have to prepare `input_ids` with the same batch size as `pixel_values`
|
||||||
|
start_token_id = model.config.bos_token_id
|
||||||
|
generated_ids = model.generate(
|
||||||
|
pixel_values=pixel_values, input_ids=torch.tensor([[start_token_id], [start_token_id]]), max_length=50
|
||||||
|
)
|
||||||
|
generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertEquals(generated_captions, ["two cats sleeping on a pink blanket next to remotes."] * 2)
|
||||||
|
|||||||
Reference in New Issue
Block a user