Generate: handle text conditioning with multimodal encoder-decoder models (#22748)
This commit is contained in:
@@ -837,12 +837,12 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
# 6. Prepare model inputs which will be used for auto-regressive generation
|
# 6. Prepare model inputs which will be used for auto-regressive generation
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
|
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
||||||
input_ids = self._prepare_decoder_input_ids_for_generation(
|
batch_size=batch_size,
|
||||||
batch_size,
|
model_input_name=model_input_name,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
decoder_start_token_id=generation_config.decoder_start_token_id,
|
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||||
bos_token_id=generation_config.bos_token_id,
|
bos_token_id=generation_config.bos_token_id,
|
||||||
model_kwargs=model_kwargs,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||||
@@ -1095,16 +1095,41 @@ class TFGenerationMixin:
|
|||||||
def _prepare_decoder_input_ids_for_generation(
|
def _prepare_decoder_input_ids_for_generation(
|
||||||
self,
|
self,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
model_input_name: str,
|
||||||
|
model_kwargs: Dict[str, tf.Tensor],
|
||||||
decoder_start_token_id: int = None,
|
decoder_start_token_id: int = None,
|
||||||
bos_token_id: int = None,
|
bos_token_id: int = None,
|
||||||
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
|
) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
|
||||||
) -> tf.Tensor:
|
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
|
||||||
# prepare `input_ids` for decoder if model is encoder-decoder
|
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
|
||||||
|
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
|
||||||
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
||||||
return model_kwargs.pop("decoder_input_ids")
|
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
|
||||||
|
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
|
||||||
|
decoder_input_ids = model_kwargs.pop("input_ids")
|
||||||
else:
|
else:
|
||||||
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
decoder_input_ids = None
|
||||||
return tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id
|
|
||||||
|
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
|
||||||
|
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
||||||
|
decoder_input_ids_start = tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id
|
||||||
|
|
||||||
|
# no user input -> use decoder_start_token_id as decoder_input_ids
|
||||||
|
if decoder_input_ids is None:
|
||||||
|
decoder_input_ids = decoder_input_ids_start
|
||||||
|
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
|
||||||
|
# decoder_attention_mask if provided)
|
||||||
|
elif tf.reduce_all(decoder_input_ids[:, 0] != decoder_start_token_id):
|
||||||
|
decoder_input_ids = tf.concat([decoder_input_ids_start, decoder_input_ids], axis=-1)
|
||||||
|
if "decoder_attention_mask" in model_kwargs:
|
||||||
|
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
||||||
|
decoder_attention_mask = tf.concat(
|
||||||
|
(tf.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
|
||||||
|
|
||||||
|
return decoder_input_ids, model_kwargs
|
||||||
|
|
||||||
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
||||||
# retrieve decoder_start_token_id for encoder-decoder models
|
# retrieve decoder_start_token_id for encoder-decoder models
|
||||||
|
|||||||
@@ -642,18 +642,44 @@ class GenerationMixin:
|
|||||||
def _prepare_decoder_input_ids_for_generation(
|
def _prepare_decoder_input_ids_for_generation(
|
||||||
self,
|
self,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
model_input_name: str,
|
||||||
|
model_kwargs: Dict[str, torch.Tensor],
|
||||||
decoder_start_token_id: int = None,
|
decoder_start_token_id: int = None,
|
||||||
bos_token_id: int = None,
|
bos_token_id: int = None,
|
||||||
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
|
||||||
device: torch.device = None,
|
device: torch.device = None,
|
||||||
) -> torch.LongTensor:
|
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
|
||||||
|
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
|
||||||
|
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
|
||||||
|
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
|
||||||
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
||||||
return model_kwargs.pop("decoder_input_ids")
|
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
|
||||||
|
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
|
||||||
|
decoder_input_ids = model_kwargs.pop("input_ids")
|
||||||
else:
|
else:
|
||||||
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
decoder_input_ids = None
|
||||||
if device is None:
|
|
||||||
device = self.device
|
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
|
||||||
return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
|
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
||||||
|
if device is None:
|
||||||
|
device = self.device
|
||||||
|
decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
|
||||||
|
|
||||||
|
# no user input -> use decoder_start_token_id as decoder_input_ids
|
||||||
|
if decoder_input_ids is None:
|
||||||
|
decoder_input_ids = decoder_input_ids_start
|
||||||
|
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
|
||||||
|
# decoder_attention_mask if provided)
|
||||||
|
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
|
||||||
|
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
|
||||||
|
if "decoder_attention_mask" in model_kwargs:
|
||||||
|
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
||||||
|
decoder_attention_mask = torch.cat(
|
||||||
|
(torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
|
||||||
|
|
||||||
|
return decoder_input_ids, model_kwargs
|
||||||
|
|
||||||
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
||||||
decoder_start_token_id = (
|
decoder_start_token_id = (
|
||||||
@@ -1289,17 +1315,14 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
input_ids = self._prepare_decoder_input_ids_for_generation(
|
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
||||||
batch_size,
|
batch_size=batch_size,
|
||||||
|
model_input_name=model_input_name,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
decoder_start_token_id=generation_config.decoder_start_token_id,
|
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||||
bos_token_id=generation_config.bos_token_id,
|
bos_token_id=generation_config.bos_token_id,
|
||||||
model_kwargs=model_kwargs,
|
|
||||||
device=inputs_tensor.device,
|
device=inputs_tensor.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
# conditional generation for multi-modal models.
|
|
||||||
if "input_ids" in model_kwargs and model_input_name == "pixel_values":
|
|
||||||
input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
|
|
||||||
else:
|
else:
|
||||||
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||||
|
|
||||||
|
|||||||
@@ -1776,35 +1776,6 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if isinstance(input_ids, torch.Tensor):
|
|
||||||
# check if the first element of `input_ids` is equal to `input_ids`:
|
|
||||||
if (input_ids[:, 0] != self.config.decoder_start_token_id).all().item():
|
|
||||||
# add `input_ids` as first token to `input_ids`
|
|
||||||
input_ids = torch.cat(
|
|
||||||
[
|
|
||||||
torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
|
|
||||||
* self.config.decoder_start_token_id,
|
|
||||||
input_ids,
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decoder_attention_mask is not None:
|
|
||||||
decoder_attention_mask = torch.cat(
|
|
||||||
[
|
|
||||||
torch.ones(
|
|
||||||
(decoder_attention_mask.shape[0], 1),
|
|
||||||
dtype=torch.long,
|
|
||||||
device=decoder_attention_mask.device,
|
|
||||||
),
|
|
||||||
decoder_attention_mask,
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
elif input_ids is None:
|
|
||||||
batch_size = flattened_patches.shape[0]
|
|
||||||
input_ids = torch.LongTensor([[self.input_ids]]).repeat(batch_size, 1).to(input_ids.device)
|
|
||||||
|
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
|
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
|
||||||
|
|
||||||
|
|||||||
@@ -94,8 +94,8 @@ class GenerationIntegrationTestsMixin:
|
|||||||
|
|
||||||
# Decoder only call
|
# Decoder only call
|
||||||
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
||||||
# 29 + 3 new tokens
|
# 1 BOS + 29 (input length) + 3 new tokens
|
||||||
self.assertEqual(list(outputs.shape), [1, 32])
|
self.assertEqual(list(outputs.shape), [1, 33])
|
||||||
|
|
||||||
# Encoder decoder call > 20
|
# Encoder decoder call > 20
|
||||||
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
|
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||||
@@ -658,3 +658,31 @@ class GenerationIntegrationTestsMixin:
|
|||||||
[token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
|
[token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
|
||||||
)
|
)
|
||||||
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
|
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
|
||||||
|
|
||||||
|
def test_generate_vision2text_conditioning(self):
|
||||||
|
model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"]
|
||||||
|
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
|
||||||
|
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
|
||||||
|
is_pt = not model_cls.__name__.startswith("TF")
|
||||||
|
|
||||||
|
pixel_values = floats_tensor((2, 3, 30, 30))
|
||||||
|
conditioning_input = create_tensor_fn([[10], [10]]) # this should be the 2nd output token, after the BOS token
|
||||||
|
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
|
||||||
|
if is_pt:
|
||||||
|
pixel_values = pixel_values.to(torch_device)
|
||||||
|
model = model.to(torch_device)
|
||||||
|
conditioning_input = conditioning_input.to(torch_device)
|
||||||
|
|
||||||
|
# we can condition on decoder_input_ids (expected decoder input) and input_ids (which we pipe internally as
|
||||||
|
# decoder_input_ids, if the encoder is not a model with text input)
|
||||||
|
output_sequences_decoder_input_ids = model.generate(
|
||||||
|
pixel_values, max_length=5, decoder_input_ids=conditioning_input
|
||||||
|
)
|
||||||
|
output_sequences_input_ids = model.generate(pixel_values, max_length=5, input_ids=conditioning_input)
|
||||||
|
if is_pt:
|
||||||
|
output_sequences_decoder_input_ids = output_sequences_decoder_input_ids.cpu().numpy()
|
||||||
|
output_sequences_input_ids = output_sequences_input_ids.cpu().numpy()
|
||||||
|
conditioning_input = conditioning_input.cpu().numpy()
|
||||||
|
|
||||||
|
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
|
||||||
|
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
|
||||||
|
|||||||
@@ -1892,8 +1892,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
max_length = 20
|
max_length = 20
|
||||||
input_ids = input_ids.expand(2, -1)
|
input_ids = input_ids.expand(2, -1)
|
||||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||||
input_ids.shape[0],
|
batch_size=input_ids.shape[0],
|
||||||
|
model_input_name=bart_model.main_input_name,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||||
bos_token_id=bart_model.config.bos_token_id,
|
bos_token_id=bart_model.config.bos_token_id,
|
||||||
)
|
)
|
||||||
@@ -1919,8 +1921,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
max_length = 20
|
max_length = 20
|
||||||
input_ids = input_ids.expand(2, -1)
|
input_ids = input_ids.expand(2, -1)
|
||||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||||
input_ids.shape[0],
|
batch_size=input_ids.shape[0],
|
||||||
|
model_input_name=bart_model.main_input_name,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||||
bos_token_id=bart_model.config.bos_token_id,
|
bos_token_id=bart_model.config.bos_token_id,
|
||||||
)
|
)
|
||||||
@@ -1949,8 +1953,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
|
|
||||||
input_ids = input_ids.expand(2, -1)
|
input_ids = input_ids.expand(2, -1)
|
||||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||||
input_ids.shape[0],
|
batch_size=input_ids.shape[0],
|
||||||
|
model_input_name=bart_model.main_input_name,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||||
bos_token_id=bart_model.config.bos_token_id,
|
bos_token_id=bart_model.config.bos_token_id,
|
||||||
)
|
)
|
||||||
@@ -1982,8 +1988,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
|
|
||||||
input_ids = input_ids.expand(6, -1)
|
input_ids = input_ids.expand(6, -1)
|
||||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||||
input_ids.shape[0],
|
batch_size=input_ids.shape[0],
|
||||||
|
model_input_name=bart_model.main_input_name,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||||
bos_token_id=bart_model.config.bos_token_id,
|
bos_token_id=bart_model.config.bos_token_id,
|
||||||
)
|
)
|
||||||
@@ -2021,8 +2029,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
# Greedy
|
# Greedy
|
||||||
input_ids = input_ids.expand(6, -1)
|
input_ids = input_ids.expand(6, -1)
|
||||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||||
input_ids.shape[0],
|
batch_size=input_ids.shape[0],
|
||||||
|
model_input_name=bart_model.main_input_name,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||||
bos_token_id=bart_model.config.bos_token_id,
|
bos_token_id=bart_model.config.bos_token_id,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user