BLIP: fix generation after hub update (#34876)
* fix blip generation * dont remove it yet * Update src/transformers/models/blip_2/modeling_blip_2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * address comments * modular --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
c1a8520419
commit
098962dac2
@@ -421,7 +421,12 @@ class GenerationMixin:
|
||||
model_input = kwargs.get(model_input_name)
|
||||
if model_input is not None:
|
||||
if past_key_values is not None:
|
||||
model_input = model_input[:, -input_ids.shape[1] :]
|
||||
current_input_length = (
|
||||
model_inputs["inputs_embeds"].shape[1]
|
||||
if model_inputs["inputs_embeds"] is not None
|
||||
else model_inputs[input_ids_key].shape[1]
|
||||
)
|
||||
model_input = model_input[:, -current_input_length:]
|
||||
model_input = model_input.clone(memory_format=torch.contiguous_format)
|
||||
model_inputs[model_input_name] = model_input
|
||||
|
||||
|
||||
@@ -2307,12 +2307,14 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
language_attention_mask = torch.ones(
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
input_ids = (
|
||||
torch.LongTensor([[self.config.text_config.bos_token_id]])
|
||||
.repeat(batch_size, 1)
|
||||
.to(image_embeds.device)
|
||||
)
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
@@ -1591,11 +1591,12 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
input_ids = (
|
||||
torch.LongTensor([[self.config.text_config.bos_token_id]])
|
||||
.repeat(batch_size, 1)
|
||||
.to(image_embeds.device)
|
||||
)
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
|
||||
@@ -1626,11 +1626,12 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
input_ids = (
|
||||
torch.LongTensor([[self.config.text_config.bos_token_id]])
|
||||
.repeat(batch_size, 1)
|
||||
.to(image_embeds.device)
|
||||
)
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
|
||||
@@ -439,11 +439,12 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
input_ids = (
|
||||
torch.LongTensor([[self.config.text_config.bos_token_id]])
|
||||
.repeat(batch_size, 1)
|
||||
.to(image_embeds.device)
|
||||
)
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
|
||||
@@ -1994,8 +1994,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Test output
|
||||
print(predictions[0].tolist(), generated_text)
|
||||
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
|
||||
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118] # fmt: skip
|
||||
self.assertEqual(predictions[0].tolist(), expected_ids)
|
||||
self.assertEqual("a woman sitting on the beach with a dog", generated_text)
|
||||
|
||||
# image and context
|
||||
@@ -2007,10 +2007,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Test output
|
||||
self.assertEqual(
|
||||
predictions[0].tolist(),
|
||||
[2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118],
|
||||
)
|
||||
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118] # fmt: skip
|
||||
self.assertEqual(predictions[0].tolist(), expected_ids)
|
||||
self.assertEqual(generated_text, "Question: which city is this? Answer: it's not a city, it's a beach")
|
||||
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
@@ -2026,7 +2024,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
predictions = model.generate(**inputs, interpolate_pos_encoding=True)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118])
|
||||
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 8, 2335, 15, 5, 4105, 50118] # fmt: skip
|
||||
self.assertEqual(predictions[0].tolist(), expected_ids)
|
||||
self.assertEqual(generated_text, "a woman and dog on the beach")
|
||||
|
||||
def test_inference_opt_batched_beam_search(self):
|
||||
@@ -2042,8 +2041,9 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
predictions = model.generate(**inputs, num_beams=2)
|
||||
|
||||
# Test output (in this case, slightly different from greedy search)
|
||||
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
|
||||
self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
|
||||
expected_ids = [50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 50265, 2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118] # fmt: skip
|
||||
self.assertEqual(predictions[0].tolist(), expected_ids)
|
||||
self.assertEqual(predictions[1].tolist(), expected_ids)
|
||||
|
||||
def test_inference_t5(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
||||
@@ -2070,10 +2070,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Test output
|
||||
self.assertEqual(
|
||||
predictions[0].tolist(),
|
||||
[0, 3, 7, 152, 67, 839, 1],
|
||||
)
|
||||
self.assertEqual(predictions[0].tolist(), [0, 3, 7, 152, 67, 839, 1])
|
||||
self.assertEqual(generated_text, "san diego")
|
||||
|
||||
def test_inference_t5_batched_beam_search(self):
|
||||
|
||||
@@ -945,7 +945,7 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
|
||||
# Add args to the config to trigger new logic when inputs are expanded in processing file
|
||||
processor.num_query_tokens = model.config.num_query_tokens
|
||||
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
||||
model.config.image_token_index = len(processor.tokenizer) - 1
|
||||
model.config.image_token_index = len(processor.tokenizer) - 2
|
||||
model.resize_token_embeddings(processor.tokenizer.vocab_size, pad_to_multiple_of=64)
|
||||
|
||||
# Generate again with new inputs
|
||||
|
||||
Reference in New Issue
Block a user