From 098962dac29d25b6fc3c3adb8554d83d9d650376 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 25 Nov 2024 10:41:55 +0100 Subject: [PATCH] BLIP: fix generation after hub update (#34876) * fix blip generation * dont remove it yet * Update src/transformers/models/blip_2/modeling_blip_2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * address comments * modular --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/utils.py | 7 +++++- .../models/blip_2/modeling_blip_2.py | 12 ++++++---- .../instructblip/modeling_instructblip.py | 11 +++++---- .../modeling_instructblipvideo.py | 11 +++++---- .../modular_instructblipvideo.py | 11 +++++---- tests/models/blip_2/test_modeling_blip_2.py | 23 ++++++++----------- .../test_modeling_instructblip.py | 2 +- 7 files changed, 42 insertions(+), 35 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c839a6538d..16b26ade7a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -421,7 +421,12 @@ class GenerationMixin: model_input = kwargs.get(model_input_name) if model_input is not None: if past_key_values is not None: - model_input = model_input[:, -input_ids.shape[1] :] + current_input_length = ( + model_inputs["inputs_embeds"].shape[1] + if model_inputs["inputs_embeds"] is not None + else model_inputs[input_ids_key].shape[1] + ) + model_input = model_input[:, -current_input_length:] model_input = model_input.clone(memory_format=torch.contiguous_format) model_inputs[model_input_name] = model_input diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index d34528b743..2e32912421 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2307,12 +2307,14 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, 
GenerationMixin): language_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) + if input_ids is None: - input_ids = ( - torch.LongTensor([[self.config.text_config.bos_token_id]]) - .repeat(batch_size, 1) - .to(image_embeds.device) - ) + start_tokens = [self.config.text_config.bos_token_id] + if getattr(self.config, "image_token_index", None) is not None: + start_tokens += [self.config.image_token_index] * self.config.num_query_tokens + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) + input_ids = input_ids.repeat(batch_size, 1) + inputs_embeds = self.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index e5622185bc..a63393ab1d 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1591,11 +1591,12 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati ) if input_ids is None: - input_ids = ( - torch.LongTensor([[self.config.text_config.bos_token_id]]) - .repeat(batch_size, 1) - .to(image_embeds.device) - ) + start_tokens = [self.config.text_config.bos_token_id] + if getattr(self.config, "image_token_index", None) is not None: + start_tokens += [self.config.image_token_index] * self.config.num_query_tokens + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) + input_ids = input_ids.repeat(batch_size, 1) + if attention_mask is None: attention_mask = torch.ones_like(input_ids) diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index b0a494dcfe..e922d1e3f2 100644 --- 
a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1626,11 +1626,12 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel ) if input_ids is None: - input_ids = ( - torch.LongTensor([[self.config.text_config.bos_token_id]]) - .repeat(batch_size, 1) - .to(image_embeds.device) - ) + start_tokens = [self.config.text_config.bos_token_id] + if getattr(self.config, "video_token_index", None) is not None: + start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4 + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) + input_ids = input_ids.repeat(batch_size, 1) + if attention_mask is None: attention_mask = torch.ones_like(input_ids) diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index b0dc8a2157..126d81b6d3 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -439,11 +439,12 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera ) if input_ids is None: - input_ids = ( - torch.LongTensor([[self.config.text_config.bos_token_id]]) - .repeat(batch_size, 1) - .to(image_embeds.device) - ) + start_tokens = [self.config.text_config.bos_token_id] + if getattr(self.config, "video_token_index", None) is not None: + start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4 + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) + input_ids = input_ids.repeat(batch_size, 1) + if attention_mask is None: attention_mask = torch.ones_like(input_ids) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index a141ef40be..a1ea708efd 100644 --- 
a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -1994,8 +1994,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase): generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() # Test output - print(predictions[0].tolist(), generated_text) - self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118]) + expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118] # fmt: skip + self.assertEqual(predictions[0].tolist(), expected_ids) self.assertEqual("a woman sitting on the beach with a dog", generated_text) # image and context @@ -2007,10 +2007,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase): generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() # Test output - self.assertEqual( - predictions[0].tolist(), - [2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118], - ) + expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118] # fmt: skip + self.assertEqual(predictions[0].tolist(), expected_ids) self.assertEqual(generated_text, "Question: which city is this? 
Answer: it's not a city, it's a beach") def test_inference_interpolate_pos_encoding(self): @@ -2026,7 +2024,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase): predictions = model.generate(**inputs, interpolate_pos_encoding=True) generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() - self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118]) + expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 8, 2335, 15, 5, 4105, 50118] # fmt: skip + self.assertEqual(predictions[0].tolist(), expected_ids) self.assertEqual(generated_text, "a woman and dog on the beach") def test_inference_opt_batched_beam_search(self): @@ -2042,8 +2041,9 @@ class Blip2ModelIntegrationTest(unittest.TestCase): predictions = model.generate(**inputs, num_beams=2) # Test output (in this case, slightly different from greedy search) - self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118]) - self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118]) + expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118] # fmt: skip + self.assertEqual(predictions[0].tolist(), expected_ids) + self.assertEqual(predictions[1].tolist(), expected_ids) def test_inference_t5(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl") @@ -2070,10 +2070,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase): generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() # Test output - self.assertEqual( 
- predictions[0].tolist(), - [0, 3, 7, 152, 67, 839, 1], - ) + self.assertEqual(predictions[0].tolist(), [0, 3, 7, 152, 67, 839, 1]) self.assertEqual(generated_text, "san diego") def test_inference_t5_batched_beam_search(self): diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index e77577dad7..baacc12caa 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -945,7 +945,7 @@ class InstructBlipModelIntegrationTest(unittest.TestCase): # Add args to the config to trigger new logic when inputs are expanded in processing file processor.num_query_tokens = model.config.num_query_tokens processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]}) - model.config.image_token_index = len(processor.tokenizer) - 1 + model.config.image_token_index = len(processor.tokenizer) - 2 model.resize_token_embeddings(processor.tokenizer.vocab_size, pad_to_multiple_of=64) # Generate again with new inputs