Generate: Fix GIT batched captioning (#21738)
This commit is contained in:
@@ -1217,7 +1217,7 @@ class TFGenerationMixin:
|
|||||||
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
|
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
|
||||||
# the attention mask) can rely on the actual model input.
|
# the attention mask) can rely on the actual model input.
|
||||||
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
|
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
|
||||||
inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
|
inputs, bos_token_id, model_kwargs=model_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if inputs is not None:
|
if inputs is not None:
|
||||||
@@ -1225,9 +1225,7 @@ class TFGenerationMixin:
|
|||||||
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
||||||
|
|
||||||
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
|
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
|
||||||
inputs = self._maybe_initialize_input_ids_for_generation(
|
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
|
||||||
inputs, bos_token_id, model_kwargs.get("encoder_outputs")
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs, input_name, model_kwargs
|
return inputs, input_name, model_kwargs
|
||||||
|
|
||||||
@@ -1235,13 +1233,13 @@ class TFGenerationMixin:
|
|||||||
self,
|
self,
|
||||||
inputs: Optional[tf.Tensor] = None,
|
inputs: Optional[tf.Tensor] = None,
|
||||||
bos_token_id: Optional[int] = None,
|
bos_token_id: Optional[int] = None,
|
||||||
encoder_outputs: Optional[ModelOutput] = None,
|
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
|
||||||
batch_size: Optional[int] = None,
|
|
||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
"""Initializes input ids for generation, if necessary."""
|
"""Initializes input ids for generation, if necessary."""
|
||||||
if inputs is not None:
|
if inputs is not None:
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
encoder_outputs = model_kwargs.get("encoder_outputs")
|
||||||
if self.config.is_encoder_decoder and encoder_outputs is not None:
|
if self.config.is_encoder_decoder and encoder_outputs is not None:
|
||||||
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
|
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
|
||||||
shape = encoder_outputs.last_hidden_state.shape[:-1]
|
shape = encoder_outputs.last_hidden_state.shape[:-1]
|
||||||
@@ -1250,7 +1248,13 @@ class TFGenerationMixin:
|
|||||||
if bos_token_id is None:
|
if bos_token_id is None:
|
||||||
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
|
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
|
||||||
|
|
||||||
batch_size = batch_size if batch_size is not None else 1
|
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
|
||||||
|
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
|
||||||
|
batch_size = 1
|
||||||
|
for value in model_kwargs.values():
|
||||||
|
if isinstance(value, tf.Tensor):
|
||||||
|
batch_size = value.shape[0]
|
||||||
|
break
|
||||||
return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
|
return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -544,7 +544,7 @@ class GenerationMixin:
|
|||||||
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
|
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
|
||||||
# the attention mask) can rely on the actual model input.
|
# the attention mask) can rely on the actual model input.
|
||||||
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
|
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
|
||||||
inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
|
inputs, bos_token_id, model_kwargs=model_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if inputs is not None:
|
if inputs is not None:
|
||||||
@@ -552,9 +552,7 @@ class GenerationMixin:
|
|||||||
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
||||||
|
|
||||||
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
|
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
|
||||||
inputs = self._maybe_initialize_input_ids_for_generation(
|
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
|
||||||
inputs, bos_token_id, model_kwargs.get("encoder_outputs")
|
|
||||||
)
|
|
||||||
return inputs, input_name, model_kwargs
|
return inputs, input_name, model_kwargs
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||||
@@ -567,13 +565,13 @@ class GenerationMixin:
|
|||||||
self,
|
self,
|
||||||
inputs: Optional[torch.Tensor] = None,
|
inputs: Optional[torch.Tensor] = None,
|
||||||
bos_token_id: Optional[int] = None,
|
bos_token_id: Optional[int] = None,
|
||||||
encoder_outputs: Optional[ModelOutput] = None,
|
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
batch_size: Optional[int] = None,
|
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
"""Initializes input ids for generation, if necessary."""
|
"""Initializes input ids for generation, if necessary."""
|
||||||
if inputs is not None:
|
if inputs is not None:
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
encoder_outputs = model_kwargs.get("encoder_outputs")
|
||||||
if self.config.is_encoder_decoder and encoder_outputs is not None:
|
if self.config.is_encoder_decoder and encoder_outputs is not None:
|
||||||
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
|
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
|
||||||
shape = encoder_outputs.last_hidden_state.size()[:-1]
|
shape = encoder_outputs.last_hidden_state.size()[:-1]
|
||||||
@@ -582,7 +580,13 @@ class GenerationMixin:
|
|||||||
if bos_token_id is None:
|
if bos_token_id is None:
|
||||||
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
|
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
|
||||||
|
|
||||||
batch_size = batch_size if batch_size is not None else 1
|
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
|
||||||
|
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
|
||||||
|
batch_size = 1
|
||||||
|
for value in model_kwargs.values():
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
batch_size = value.shape[0]
|
||||||
|
break
|
||||||
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
|
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
|
||||||
|
|
||||||
def _prepare_attention_mask_for_generation(
|
def _prepare_attention_mask_for_generation(
|
||||||
|
|||||||
@@ -340,6 +340,24 @@ class GitModelTester:
|
|||||||
|
|
||||||
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||||
|
|
||||||
|
def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values):
|
||||||
|
model = GitForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# generate
|
||||||
|
generated_ids = model.generate(
|
||||||
|
input_ids=None, # captioning -> no input_ids
|
||||||
|
attention_mask=None,
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
do_sample=False,
|
||||||
|
max_length=20,
|
||||||
|
num_beams=2,
|
||||||
|
num_return_sequences=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
@@ -398,6 +416,10 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester._test_beam_search_generate(*config_and_inputs)
|
self.model_tester._test_beam_search_generate(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_batched_generate_captioning(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester._test_batched_generate_captioning(*config_and_inputs)
|
||||||
|
|
||||||
def test_model_various_embeddings(self):
|
def test_model_various_embeddings(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||||
|
|||||||
Reference in New Issue
Block a user