@@ -347,8 +347,7 @@ class GitModelTester:
|
|||||||
num_return_sequences=2,
|
num_return_sequences=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.parent.assertEqual(generated_ids.shape[0], self.batch_size * 2)
|
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||||
self.parent.assertTrue(generated_ids.shape[1] < 20)
|
|
||||||
|
|
||||||
def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values):
|
def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values):
|
||||||
model = GitForCausalLM(config=config)
|
model = GitForCausalLM(config=config)
|
||||||
|
|||||||
Reference in New Issue
Block a user