Generate: Fix GIT batched captioning (#21738)
This commit is contained in:
@@ -340,6 +340,24 @@ class GitModelTester:
|
||||
|
||||
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||
|
||||
def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values):
|
||||
model = GitForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# generate
|
||||
generated_ids = model.generate(
|
||||
input_ids=None, # captioning -> no input_ids
|
||||
attention_mask=None,
|
||||
pixel_values=pixel_values,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
|
||||
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
@@ -398,6 +416,10 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester._test_beam_search_generate(*config_and_inputs)
|
||||
|
||||
def test_batched_generate_captioning(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester._test_batched_generate_captioning(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
|
||||
Reference in New Issue
Block a user