Generate: handle text conditioning with multimodal encoder-decoder models (#22748)
This commit is contained in:
@@ -69,4 +69,4 @@ The original code can be found [here](https://github.com/google-research/pix2str
|
||||
## Pix2StructForConditionalGeneration
|
||||
|
||||
[[autodoc]] Pix2StructForConditionalGeneration
|
||||
- forward
|
||||
- forward
|
||||
|
||||
@@ -837,12 +837,12 @@ class TFGenerationMixin:
|
||||
|
||||
# 6. Prepare model inputs which will be used for auto-regressive generation
|
||||
if self.config.is_encoder_decoder:
|
||||
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
|
||||
input_ids = self._prepare_decoder_input_ids_for_generation(
|
||||
batch_size,
|
||||
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=batch_size,
|
||||
model_input_name=model_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||
bos_token_id=generation_config.bos_token_id,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
else:
|
||||
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||
@@ -1095,16 +1095,41 @@ class TFGenerationMixin:
|
||||
def _prepare_decoder_input_ids_for_generation(
|
||||
self,
|
||||
batch_size: int,
|
||||
model_input_name: str,
|
||||
model_kwargs: Dict[str, tf.Tensor],
|
||||
decoder_start_token_id: int = None,
|
||||
bos_token_id: int = None,
|
||||
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
|
||||
) -> tf.Tensor:
|
||||
# prepare `input_ids` for decoder if model is encoder-decoder
|
||||
) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
|
||||
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
|
||||
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
|
||||
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
|
||||
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
||||
return model_kwargs.pop("decoder_input_ids")
|
||||
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
|
||||
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
|
||||
decoder_input_ids = model_kwargs.pop("input_ids")
|
||||
else:
|
||||
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
||||
return tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id
|
||||
decoder_input_ids = None
|
||||
|
||||
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
|
||||
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
||||
decoder_input_ids_start = tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id
|
||||
|
||||
# no user input -> use decoder_start_token_id as decoder_input_ids
|
||||
if decoder_input_ids is None:
|
||||
decoder_input_ids = decoder_input_ids_start
|
||||
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
|
||||
# decoder_attention_mask if provided)
|
||||
elif tf.reduce_all(decoder_input_ids[:, 0] != decoder_start_token_id):
|
||||
decoder_input_ids = tf.concat([decoder_input_ids_start, decoder_input_ids], axis=-1)
|
||||
if "decoder_attention_mask" in model_kwargs:
|
||||
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
||||
decoder_attention_mask = tf.concat(
|
||||
(tf.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
|
||||
axis=-1,
|
||||
)
|
||||
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
|
||||
|
||||
return decoder_input_ids, model_kwargs
|
||||
|
||||
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
||||
# retrieve decoder_start_token_id for encoder-decoder models
|
||||
|
||||
@@ -642,18 +642,44 @@ class GenerationMixin:
|
||||
def _prepare_decoder_input_ids_for_generation(
|
||||
self,
|
||||
batch_size: int,
|
||||
model_input_name: str,
|
||||
model_kwargs: Dict[str, torch.Tensor],
|
||||
decoder_start_token_id: int = None,
|
||||
bos_token_id: int = None,
|
||||
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
device: torch.device = None,
|
||||
) -> torch.LongTensor:
|
||||
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
|
||||
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
|
||||
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
|
||||
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
|
||||
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
||||
return model_kwargs.pop("decoder_input_ids")
|
||||
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
|
||||
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
|
||||
decoder_input_ids = model_kwargs.pop("input_ids")
|
||||
else:
|
||||
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
||||
if device is None:
|
||||
device = self.device
|
||||
return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
|
||||
decoder_input_ids = None
|
||||
|
||||
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
|
||||
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
||||
if device is None:
|
||||
device = self.device
|
||||
decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
|
||||
|
||||
# no user input -> use decoder_start_token_id as decoder_input_ids
|
||||
if decoder_input_ids is None:
|
||||
decoder_input_ids = decoder_input_ids_start
|
||||
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
|
||||
# decoder_attention_mask if provided)
|
||||
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
|
||||
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
|
||||
if "decoder_attention_mask" in model_kwargs:
|
||||
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
||||
decoder_attention_mask = torch.cat(
|
||||
(torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
|
||||
dim=-1,
|
||||
)
|
||||
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
|
||||
|
||||
return decoder_input_ids, model_kwargs
|
||||
|
||||
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
||||
decoder_start_token_id = (
|
||||
@@ -1289,17 +1315,14 @@ class GenerationMixin:
|
||||
|
||||
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||
if self.config.is_encoder_decoder:
|
||||
input_ids = self._prepare_decoder_input_ids_for_generation(
|
||||
batch_size,
|
||||
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=batch_size,
|
||||
model_input_name=model_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||
bos_token_id=generation_config.bos_token_id,
|
||||
model_kwargs=model_kwargs,
|
||||
device=inputs_tensor.device,
|
||||
)
|
||||
|
||||
# conditional generation for multi-modal models.
|
||||
if "input_ids" in model_kwargs and model_input_name == "pixel_values":
|
||||
input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
|
||||
else:
|
||||
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||
|
||||
|
||||
@@ -1776,35 +1776,6 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
||||
encoder_outputs=None,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(input_ids, torch.Tensor):
|
||||
# check if the first element of `input_ids` is equal to `input_ids`:
|
||||
if (input_ids[:, 0] != self.config.decoder_start_token_id).all().item():
|
||||
# add `input_ids` as first token to `input_ids`
|
||||
input_ids = torch.cat(
|
||||
[
|
||||
torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
|
||||
* self.config.decoder_start_token_id,
|
||||
input_ids,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if decoder_attention_mask is not None:
|
||||
decoder_attention_mask = torch.cat(
|
||||
[
|
||||
torch.ones(
|
||||
(decoder_attention_mask.shape[0], 1),
|
||||
dtype=torch.long,
|
||||
device=decoder_attention_mask.device,
|
||||
),
|
||||
decoder_attention_mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
elif input_ids is None:
|
||||
batch_size = flattened_patches.shape[0]
|
||||
input_ids = torch.LongTensor([[self.input_ids]]).repeat(batch_size, 1).to(input_ids.device)
|
||||
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
|
||||
|
||||
|
||||
@@ -94,8 +94,8 @@ class GenerationIntegrationTestsMixin:
|
||||
|
||||
# Decoder only call
|
||||
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
||||
# 29 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 32])
|
||||
# 1 BOS + 29 (input length) + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 33])
|
||||
|
||||
# Encoder decoder call > 20
|
||||
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||
@@ -658,3 +658,31 @@ class GenerationIntegrationTestsMixin:
|
||||
[token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
|
||||
)
|
||||
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
|
||||
|
||||
def test_generate_vision2text_conditioning(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"]
|
||||
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
|
||||
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
pixel_values = floats_tensor((2, 3, 30, 30))
|
||||
conditioning_input = create_tensor_fn([[10], [10]]) # this should be the 2nd output token, after the BOS token
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
|
||||
if is_pt:
|
||||
pixel_values = pixel_values.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
conditioning_input = conditioning_input.to(torch_device)
|
||||
|
||||
# we can condition on decoder_input_ids (expected decoder input) and input_ids (which we pipe internally as
|
||||
# decoder_input_ids, if the encoder is not a model with text input)
|
||||
output_sequences_decoder_input_ids = model.generate(
|
||||
pixel_values, max_length=5, decoder_input_ids=conditioning_input
|
||||
)
|
||||
output_sequences_input_ids = model.generate(pixel_values, max_length=5, input_ids=conditioning_input)
|
||||
if is_pt:
|
||||
output_sequences_decoder_input_ids = output_sequences_decoder_input_ids.cpu().numpy()
|
||||
output_sequences_input_ids = output_sequences_input_ids.cpu().numpy()
|
||||
conditioning_input = conditioning_input.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
|
||||
|
||||
@@ -1892,8 +1892,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
max_length = 20
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1919,8 +1921,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
max_length = 20
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1949,8 +1953,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1982,8 +1988,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
input_ids = input_ids.expand(6, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -2021,8 +2029,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
# Greedy
|
||||
input_ids = input_ids.expand(6, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user