Prepend bos token to Blip generations (#29642)
* prepend "bos" to blip generation
* minor changes
* Update src/transformers/models/blip_2/modeling_blip_2.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/models/instructblip/modeling_instructblip.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* add generation tester mixin
---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ee38fc31fb
commit
b469ebc5cf
@@ -1828,8 +1828,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
|
|||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||||
|
|
||||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||||
|
# -1 is to account for the prepended BOS after `generate`.
|
||||||
|
# TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
if not self.language_model.config.is_encoder_decoder:
|
||||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
|
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||||
|
|
||||||
outputs = self.language_model.generate(
|
outputs = self.language_model.generate(
|
||||||
@@ -1838,4 +1840,16 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
|
|||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# this is a temporary workaround to be consistent with other generation models and
|
||||||
|
# have BOS as the first token, even though under the hood we are calling LM with embeds
|
||||||
|
if not self.language_model.config.is_encoder_decoder:
|
||||||
|
bos_tokens = (
|
||||||
|
torch.LongTensor([[self.config.text_config.bos_token_id]])
|
||||||
|
.repeat(batch_size, 1)
|
||||||
|
.to(image_embeds.device)
|
||||||
|
)
|
||||||
|
if not isinstance(outputs, torch.Tensor):
|
||||||
|
outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
|
||||||
|
else:
|
||||||
|
outputs = torch.cat([bos_tokens, outputs], dim=-1)
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -1538,8 +1538,9 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
|
|||||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||||
|
|
||||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||||
|
# -1 is to account for the prepended BOS after `generate`.
|
||||||
if not self.language_model.config.is_encoder_decoder:
|
if not self.language_model.config.is_encoder_decoder:
|
||||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
|
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||||
|
|
||||||
outputs = self.language_model.generate(
|
outputs = self.language_model.generate(
|
||||||
@@ -1548,13 +1549,21 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
|
|||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# the InstructBLIP authors used inconsistent tokenizer/model files during training,
|
# this is a temporary workaround to be consistent with other generation models and
|
||||||
# with the tokenizer's bos token being set to </s> which has ID=2,
|
# have BOS as the first token, even though under the hood we are calling LM with embeds
|
||||||
# whereas the model's text config has bos token id = 0
|
if not self.language_model.config.is_encoder_decoder:
|
||||||
if self.config.text_config.architectures[0] == "LLaMAForCausalLM":
|
# the InstructBLIP authors used inconsistent tokenizer/model files during training,
|
||||||
if isinstance(outputs, torch.Tensor):
|
# with the tokenizer's bos token being set to </s> which has ID=2,
|
||||||
outputs[outputs == 0] = 2
|
# whereas the model's text config has bos token id = 0
|
||||||
|
bos_token_id = (
|
||||||
|
2
|
||||||
|
if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
|
||||||
|
else self.config.text_config.bos_token_id
|
||||||
|
)
|
||||||
|
bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
|
||||||
|
if not isinstance(outputs, torch.Tensor):
|
||||||
|
outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
|
||||||
else:
|
else:
|
||||||
outputs.sequences[outputs.sequences == 0] = 2
|
outputs = torch.cat([bos_tokens, outputs], dim=-1)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from transformers.testing_utils import (
|
|||||||
)
|
)
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import (
|
from ...test_modeling_common import (
|
||||||
ModelTesterMixin,
|
ModelTesterMixin,
|
||||||
@@ -434,7 +435,7 @@ class Blip2ForConditionalGenerationDecoderOnlyModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, unittest.TestCase):
|
class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
@@ -683,7 +684,7 @@ class Blip2ModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
|
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
@@ -869,7 +870,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
|||||||
prompt = "Question: which city is this? Answer:"
|
prompt = "Question: which city is this? Answer:"
|
||||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||||
|
|
||||||
predictions = model.generate(**inputs)
|
# max_length for BLIP includes prompt length from now on, use max_new_tokens
|
||||||
|
predictions = model.generate(**inputs, max_new_tokens=11)
|
||||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||||
|
|
||||||
# Test output
|
# Test output
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from transformers.testing_utils import (
|
|||||||
)
|
)
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import (
|
from ...test_modeling_common import (
|
||||||
ModelTesterMixin,
|
ModelTesterMixin,
|
||||||
@@ -452,7 +453,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, unittest.TestCase):
|
class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|||||||
Reference in New Issue
Block a user