Prepend bos token to Blip generations (#29642)
* prepend "bos" to blip generation
* minor changes
* Update src/transformers/models/blip_2/modeling_blip_2.py

  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/models/instructblip/modeling_instructblip.py

  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* add generation tester mixin

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ee38fc31fb
commit
b469ebc5cf
@@ -32,6 +32,7 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
@@ -434,7 +435,7 @@ class Blip2ForConditionalGenerationDecoderOnlyModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, unittest.TestCase):
|
||||
class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
@@ -683,7 +684,7 @@ class Blip2ModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
@@ -869,7 +870,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
prompt = "Question: which city is this? Answer:"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
# max_length for BLIP includes prompt length from now on, use max_new_tokens
|
||||
predictions = model.generate(**inputs, max_new_tokens=11)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Test output
|
||||
|
||||
@@ -39,6 +39,7 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
@@ -452,7 +453,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, unittest.TestCase):
|
||||
class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
|
||||
Reference in New Issue
Block a user