Generate: Fix GIT batched captioning (#21738)

This commit is contained in:
Joao Gante
2023-02-23 09:50:37 +00:00
committed by GitHub
parent 78a93d17c0
commit 1d4b797852
3 changed files with 44 additions and 14 deletions

View File

@@ -340,6 +340,24 @@ class GitModelTester:
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values):
model = GitForCausalLM(config=config)
model.to(torch_device)
model.eval()
# generate
generated_ids = model.generate(
input_ids=None, # captioning -> no input_ids
attention_mask=None,
pixel_values=pixel_values,
do_sample=False,
max_length=20,
num_beams=2,
num_return_sequences=2,
)
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@@ -398,6 +416,10 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester._test_beam_search_generate(*config_and_inputs)
def test_batched_generate_captioning(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester._test_batched_generate_captioning(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]: