BLIP: fix generation after hub update (#34876)

* fix blip generation

* dont remove it yet

* Update src/transformers/models/blip_2/modeling_blip_2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* address comments

* modular

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-11-25 10:41:55 +01:00
committed by GitHub
parent c1a8520419
commit 098962dac2
7 changed files with 42 additions and 35 deletions

View File

@@ -421,7 +421,12 @@ class GenerationMixin:
model_input = kwargs.get(model_input_name)
if model_input is not None:
if past_key_values is not None:
model_input = model_input[:, -input_ids.shape[1] :]
current_input_length = (
model_inputs["inputs_embeds"].shape[1]
if model_inputs["inputs_embeds"] is not None
else model_inputs[input_ids_key].shape[1]
)
model_input = model_input[:, -current_input_length:]
model_input = model_input.clone(memory_format=torch.contiguous_format)
model_inputs[model_input_name] = model_input

View File

@@ -2307,12 +2307,14 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
if input_ids is None:
input_ids = (
torch.LongTensor([[self.config.text_config.bos_token_id]])
.repeat(batch_size, 1)
.to(image_embeds.device)
)
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_index", None) is not None:
start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

View File

@@ -1591,11 +1591,12 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
)
if input_ids is None:
input_ids = (
torch.LongTensor([[self.config.text_config.bos_token_id]])
.repeat(batch_size, 1)
.to(image_embeds.device)
)
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_index", None) is not None:
start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

View File

@@ -1626,11 +1626,12 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
)
if input_ids is None:
input_ids = (
torch.LongTensor([[self.config.text_config.bos_token_id]])
.repeat(batch_size, 1)
.to(image_embeds.device)
)
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_index", None) is not None:
start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

View File

@@ -439,11 +439,12 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
)
if input_ids is None:
input_ids = (
torch.LongTensor([[self.config.text_config.bos_token_id]])
.repeat(batch_size, 1)
.to(image_embeds.device)
)
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_index", None) is not None:
start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

View File

@@ -1994,8 +1994,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
print(predictions[0].tolist(), generated_text)
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118] # fmt: skip
self.assertEqual(predictions[0].tolist(), expected_ids)
self.assertEqual("a woman sitting on the beach with a dog", generated_text)
# image and context
@@ -2007,10 +2007,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
self.assertEqual(
predictions[0].tolist(),
[2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118],
)
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118] # fmt: skip
self.assertEqual(predictions[0].tolist(), expected_ids)
self.assertEqual(generated_text, "Question: which city is this? Answer: it's not a city, it's a beach")
def test_inference_interpolate_pos_encoding(self):
@@ -2026,7 +2024,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
predictions = model.generate(**inputs, interpolate_pos_encoding=True)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118])
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 8, 2335, 15, 5, 4105, 50118] # fmt: skip
self.assertEqual(predictions[0].tolist(), expected_ids)
self.assertEqual(generated_text, "a woman and dog on the beach")
def test_inference_opt_batched_beam_search(self):
@@ -2042,8 +2041,9 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
predictions = model.generate(**inputs, num_beams=2)
# Test output (in this case, slightly different from greedy search)
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118] # fmt: skip
self.assertEqual(predictions[0].tolist(), expected_ids)
self.assertEqual(predictions[1].tolist(), expected_ids)
def test_inference_t5(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
@@ -2070,10 +2070,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
self.assertEqual(
predictions[0].tolist(),
[0, 3, 7, 152, 67, 839, 1],
)
self.assertEqual(predictions[0].tolist(), [0, 3, 7, 152, 67, 839, 1])
self.assertEqual(generated_text, "san diego")
def test_inference_t5_batched_beam_search(self):

View File

@@ -945,7 +945,7 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
# Add args to the config to trigger new logic when inputs are expanded in processing file
processor.num_query_tokens = model.config.num_query_tokens
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
model.config.image_token_index = len(processor.tokenizer) - 1
model.config.image_token_index = len(processor.tokenizer) - 2
model.resize_token_embeddings(processor.tokenizer.vocab_size, pad_to_multiple_of=64)
# Generate again with new inputs