From 9dfd6a4baa0dde80db0a206b84c15c8dea1164c1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 13 Apr 2023 19:51:13 +0100 Subject: [PATCH] Generate: handle text conditioning with multimodal encoder-decoder models (#22748) --- docs/source/en/model_doc/pix2struct.mdx | 2 +- src/transformers/generation/tf_utils.py | 45 ++++++++++++---- src/transformers/generation/utils.py | 51 ++++++++++++++----- .../models/pix2struct/modeling_pix2struct.py | 29 ----------- tests/generation/test_framework_agnostic.py | 32 +++++++++++- tests/generation/test_utils.py | 30 +++++++---- 6 files changed, 123 insertions(+), 66 deletions(-) diff --git a/docs/source/en/model_doc/pix2struct.mdx b/docs/source/en/model_doc/pix2struct.mdx index fb4ecf05e0..c6d3136285 100644 --- a/docs/source/en/model_doc/pix2struct.mdx +++ b/docs/source/en/model_doc/pix2struct.mdx @@ -69,4 +69,4 @@ The original code can be found [here](https://github.com/google-research/pix2str ## Pix2StructForConditionalGeneration [[autodoc]] Pix2StructForConditionalGeneration - - forward \ No newline at end of file + - forward diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index 749c07d547..cc95cb31a4 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -837,12 +837,12 @@ class TFGenerationMixin: # 6. Prepare model inputs which will be used for auto-regressive generation if self.config.is_encoder_decoder: - # if encoder-decoder then `input_ids` come from `decoder_start_token_id` - input_ids = self._prepare_decoder_input_ids_for_generation( - batch_size, + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, decoder_start_token_id=generation_config.decoder_start_token_id, bos_token_id=generation_config.bos_token_id, - model_kwargs=model_kwargs, ) else: input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") @@ -1095,16 +1095,41 @@ class TFGenerationMixin: def _prepare_decoder_input_ids_for_generation( self, batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, tf.Tensor], decoder_start_token_id: int = None, bos_token_id: int = None, - model_kwargs: Optional[Dict[str, tf.Tensor]] = None, - ) -> tf.Tensor: - # prepare `input_ids` for decoder if model is encoder-decoder + ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - return model_kwargs.pop("decoder_input_ids") + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") else: - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) - return tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + decoder_input_ids_start = tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif tf.reduce_all(decoder_input_ids[:, 0] != decoder_start_token_id): + decoder_input_ids = tf.concat([decoder_input_ids_start, decoder_input_ids], axis=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = tf.concat( + (tf.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + axis=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + + return decoder_input_ids, model_kwargs def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: # retrieve decoder_start_token_id for encoder-decoder models diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ae12ae2930..1200fbe5d9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -642,18 +642,44 @@ class GenerationMixin: def _prepare_decoder_input_ids_for_generation( self, batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], decoder_start_token_id: int = None, bos_token_id: int = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, device: torch.device = None, - ) -> torch.LongTensor: + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - return model_kwargs.pop("decoder_input_ids") + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") else: - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) - if device is None: - device = self.device - return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + + return decoder_input_ids, model_kwargs def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: decoder_start_token_id = ( @@ -1289,17 +1315,14 @@ class GenerationMixin: # 5. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: - input_ids = self._prepare_decoder_input_ids_for_generation( - batch_size, + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, decoder_start_token_id=generation_config.decoder_start_token_id, bos_token_id=generation_config.bos_token_id, - model_kwargs=model_kwargs, device=inputs_tensor.device, ) - - # conditional generation for multi-modal models. - if "input_ids" in model_kwargs and model_input_name == "pixel_values": - input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1) else: input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 6ce279f027..aec56466a0 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1776,35 +1776,6 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): encoder_outputs=None, **kwargs, ): - if isinstance(input_ids, torch.Tensor): - # check if the first element of `input_ids` is equal to `input_ids`: - if (input_ids[:, 0] != self.config.decoder_start_token_id).all().item(): - # add `input_ids` as first token to `input_ids` - input_ids = torch.cat( - [ - torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) - * self.config.decoder_start_token_id, - input_ids, - ], - dim=-1, - ) - - if decoder_attention_mask is not None: - decoder_attention_mask = torch.cat( - [ - torch.ones( - (decoder_attention_mask.shape[0], 1), - dtype=torch.long, - device=decoder_attention_mask.device, - ), - decoder_attention_mask, - ], - dim=-1, - ) - elif input_ids is None: - batch_size = flattened_patches.shape[0] - input_ids = torch.LongTensor([[self.input_ids]]).repeat(batch_size, 1).to(input_ids.device) - if decoder_attention_mask is None: decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device) diff --git a/tests/generation/test_framework_agnostic.py b/tests/generation/test_framework_agnostic.py index 72f0b5dc14..61845aa9bc 100644 --- a/tests/generation/test_framework_agnostic.py +++ b/tests/generation/test_framework_agnostic.py @@ -94,8 +94,8 @@ class GenerationIntegrationTestsMixin: # Decoder only call outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens) - # 29 + 3 new tokens - self.assertEqual(list(outputs.shape), [1, 32]) + # 1 BOS + 29 (input length) + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 33]) # Encoder decoder call > 20 outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20) @@ -658,3 +658,31 @@ class GenerationIntegrationTestsMixin: [token == model.config.pad_token_id for token in generated_tokens[0][expectation:]] ) self.assertTrue(unpadded_correct_condition or padded_correct_condition) + + def test_generate_vision2text_conditioning(self): + model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"] + floats_tensor = self.framework_dependent_parameters["floats_tensor"] + create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"] + is_pt = not model_cls.__name__.startswith("TF") + + pixel_values = floats_tensor((2, 3, 30, 30)) + conditioning_input = create_tensor_fn([[10], [10]]) # this should be the 2nd output token, after the BOS token + model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2") + if is_pt: + pixel_values = pixel_values.to(torch_device) + model = model.to(torch_device) + conditioning_input = conditioning_input.to(torch_device) + + # we can condition on decoder_input_ids (expected decoder input) and input_ids (which we pipe internally as + # decoder_input_ids, if the encoder is not a model with text input) + output_sequences_decoder_input_ids = model.generate( + pixel_values, max_length=5, decoder_input_ids=conditioning_input + ) + output_sequences_input_ids = model.generate(pixel_values, max_length=5, input_ids=conditioning_input) + if is_pt: + output_sequences_decoder_input_ids = output_sequences_decoder_input_ids.cpu().numpy() + output_sequences_input_ids = output_sequences_input_ids.cpu().numpy() + conditioning_input = conditioning_input.cpu().numpy() + + self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids)) + self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input)) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c0278f6ae4..dffaba4fb6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1892,8 +1892,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi max_length = 20 input_ids = input_ids.expand(2, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids.shape[0], + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) @@ -1919,8 +1921,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi max_length = 20 input_ids = input_ids.expand(2, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids.shape[0], + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) @@ -1949,8 +1953,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi input_ids = input_ids.expand(2, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids.shape[0], + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) @@ -1982,8 +1988,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi input_ids = input_ids.expand(6, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids.shape[0], + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) @@ -2021,8 +2029,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi # Greedy input_ids = input_ids.expand(6, -1) model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids = bart_model._prepare_decoder_input_ids_for_generation( - input_ids.shape[0], + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, )