[Blip2] Add Blip2Model (#21817)
* add v1 * add `Blip2Model` - add relevant functions - add tests - add on automapping * fix docs * fix doctest
This commit is contained in:
@@ -71,6 +71,14 @@ If you're interested in submitting a resource to be included here, please feel f
|
|||||||
[[autodoc]] Blip2QFormerModel
|
[[autodoc]] Blip2QFormerModel
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## Blip2Model
|
||||||
|
|
||||||
|
[[autodoc]] Blip2Model
|
||||||
|
- forward
|
||||||
|
- get_text_features
|
||||||
|
- get_image_features
|
||||||
|
- get_qformer_features
|
||||||
|
|
||||||
## Blip2ForConditionalGeneration
|
## Blip2ForConditionalGeneration
|
||||||
|
|
||||||
[[autodoc]] Blip2ForConditionalGeneration
|
[[autodoc]] Blip2ForConditionalGeneration
|
||||||
|
|||||||
@@ -1191,6 +1191,7 @@ else:
|
|||||||
[
|
[
|
||||||
"BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"Blip2ForConditionalGeneration",
|
"Blip2ForConditionalGeneration",
|
||||||
|
"Blip2Model",
|
||||||
"Blip2PreTrainedModel",
|
"Blip2PreTrainedModel",
|
||||||
"Blip2QFormerModel",
|
"Blip2QFormerModel",
|
||||||
"Blip2VisionModel",
|
"Blip2VisionModel",
|
||||||
@@ -4651,6 +4652,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.blip_2 import (
|
from .models.blip_2 import (
|
||||||
BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
Blip2ForConditionalGeneration,
|
Blip2ForConditionalGeneration,
|
||||||
|
Blip2Model,
|
||||||
Blip2PreTrainedModel,
|
Blip2PreTrainedModel,
|
||||||
Blip2QFormerModel,
|
Blip2QFormerModel,
|
||||||
Blip2VisionModel,
|
Blip2VisionModel,
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("blenderbot", "BlenderbotModel"),
|
("blenderbot", "BlenderbotModel"),
|
||||||
("blenderbot-small", "BlenderbotSmallModel"),
|
("blenderbot-small", "BlenderbotSmallModel"),
|
||||||
("blip", "BlipModel"),
|
("blip", "BlipModel"),
|
||||||
|
("blip_2", "Blip2Model"),
|
||||||
("bloom", "BloomModel"),
|
("bloom", "BloomModel"),
|
||||||
("bridgetower", "BridgeTowerModel"),
|
("bridgetower", "BridgeTowerModel"),
|
||||||
("camembert", "CamembertModel"),
|
("camembert", "CamembertModel"),
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modeling_blip_2"] = [
|
_import_structure["modeling_blip_2"] = [
|
||||||
"BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"Blip2Model",
|
||||||
"Blip2QFormerModel",
|
"Blip2QFormerModel",
|
||||||
"Blip2PreTrainedModel",
|
"Blip2PreTrainedModel",
|
||||||
"Blip2ForConditionalGeneration",
|
"Blip2ForConditionalGeneration",
|
||||||
@@ -58,6 +59,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_blip_2 import (
|
from .modeling_blip_2 import (
|
||||||
BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
Blip2ForConditionalGeneration,
|
Blip2ForConditionalGeneration,
|
||||||
|
Blip2Model,
|
||||||
Blip2PreTrainedModel,
|
Blip2PreTrainedModel,
|
||||||
Blip2QFormerModel,
|
Blip2QFormerModel,
|
||||||
Blip2VisionModel,
|
Blip2VisionModel,
|
||||||
|
|||||||
@@ -342,6 +342,43 @@ BLIP_2_VISION_INPUTS_DOCSTRING = r"""
|
|||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
BLIP_2_TEXT_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
|
it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||||
|
Indices of decoder input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||||
|
|
||||||
|
T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
|
||||||
|
is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
||||||
|
|
||||||
|
To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
|
||||||
|
Training](./t5#training).
|
||||||
|
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||||
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||||
|
be used by default.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
BLIP_2_INPUTS_DOCSTRING = r"""
|
BLIP_2_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
@@ -1171,6 +1208,337 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
|
||||||
|
(Q-Former) and a language model.
|
||||||
|
""",
|
||||||
|
BLIP_2_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class Blip2Model(Blip2PreTrainedModel):
|
||||||
|
config_class = Blip2Config
|
||||||
|
main_input_name = "pixel_values"
|
||||||
|
|
||||||
|
def __init__(self, config: Blip2Config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.vision_model = Blip2VisionModel(config.vision_config)
|
||||||
|
|
||||||
|
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||||
|
self.qformer = Blip2QFormerModel(config.qformer_config)
|
||||||
|
|
||||||
|
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
||||||
|
if config.use_decoder_only_language_model:
|
||||||
|
language_model = AutoModelForCausalLM.from_config(config.text_config)
|
||||||
|
else:
|
||||||
|
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
|
||||||
|
self.language_model = language_model
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
|
return self.vision_model.embeddings.patch_embedding
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING)
|
||||||
|
def get_text_features(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
decoder_input_ids: Optional[torch.Tensor] = None,
|
||||||
|
decoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
labels: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
|
||||||
|
The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
|
||||||
|
contains the language model logits, the past key values and the hidden states if
|
||||||
|
`output_hidden_states=True`.
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import AutoTokenizer, Blip2Model
|
||||||
|
|
||||||
|
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
|
||||||
|
|
||||||
|
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||||
|
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||||
|
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt").to(device)
|
||||||
|
>>> text_features = model.get_text_features(**inputs)
|
||||||
|
```"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if self.config.use_decoder_only_language_model:
|
||||||
|
text_outputs = self.language_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
text_outputs = self.language_model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
labels=labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
return text_outputs
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING)
|
||||||
|
def get_image_features(
|
||||||
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
|
||||||
|
The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that
|
||||||
|
contains the image features, the pooled image features and the hidden states if
|
||||||
|
`output_hidden_states=True`.
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Blip2Model
|
||||||
|
|
||||||
|
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
|
||||||
|
|
||||||
|
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||||
|
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
>>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
|
||||||
|
>>> image_outputs = model.get_image_features(**inputs)
|
||||||
|
```"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
return vision_outputs
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING)
|
||||||
|
def get_qformer_features(
|
||||||
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
|
||||||
|
The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that
|
||||||
|
contains the image features, the pooled image features and the hidden states if
|
||||||
|
`output_hidden_states=True`.
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import Blip2Processor, Blip2Model
|
||||||
|
|
||||||
|
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||||
|
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
|
||||||
|
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
>>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
|
||||||
|
>>> qformer_outputs = model.get_qformer_features(**inputs)
|
||||||
|
```"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_embeds = vision_outputs[0]
|
||||||
|
|
||||||
|
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
||||||
|
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||||
|
|
||||||
|
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||||
|
query_outputs = self.qformer(
|
||||||
|
query_embeds=query_tokens,
|
||||||
|
encoder_hidden_states=image_embeds,
|
||||||
|
encoder_attention_mask=image_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
return query_outputs
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=Blip2ForConditionalGenerationModelOutput, config_class=Blip2VisionConfig)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
input_ids: torch.FloatTensor,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import Blip2Processor, Blip2Model
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||||
|
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
|
||||||
|
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> prompt = "Question: how many cats are there? Answer:"
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
```"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# step 1: forward the images through the vision encoder,
|
||||||
|
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
image_embeds = vision_outputs[0]
|
||||||
|
|
||||||
|
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
||||||
|
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||||
|
|
||||||
|
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||||
|
query_outputs = self.qformer(
|
||||||
|
query_embeds=query_tokens,
|
||||||
|
encoder_hidden_states=image_embeds,
|
||||||
|
encoder_attention_mask=image_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
query_output = query_outputs[0]
|
||||||
|
|
||||||
|
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||||
|
language_model_inputs = self.language_projection(query_output)
|
||||||
|
language_model_attention_mask = torch.ones(
|
||||||
|
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||||
|
)
|
||||||
|
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||||
|
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1)
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
expected_device = language_model_attention_mask.device
|
||||||
|
attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)
|
||||||
|
|
||||||
|
if self.config.use_decoder_only_language_model:
|
||||||
|
outputs = self.language_model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
logits = outputs.logits if return_dict else outputs[0]
|
||||||
|
loss = None
|
||||||
|
# we compute the loss here since we need to take into account the sequence length of the query embeds
|
||||||
|
if labels is not None:
|
||||||
|
logits = logits[:, -labels.size(1) :, :]
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous().to(logits.device)
|
||||||
|
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss(reduction="mean")
|
||||||
|
|
||||||
|
loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
|
||||||
|
else:
|
||||||
|
outputs = self.language_model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
labels=labels,
|
||||||
|
)
|
||||||
|
loss = outputs.loss if return_dict else outputs[0]
|
||||||
|
logits = outputs.logits if return_dict else outputs[1]
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits, vision_outputs, query_outputs, outputs)
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return Blip2ForConditionalGenerationModelOutput(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
vision_outputs=vision_outputs,
|
||||||
|
qformer_outputs=query_outputs,
|
||||||
|
language_model_outputs=outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision
|
BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision
|
||||||
|
|||||||
@@ -1221,6 +1221,13 @@ class Blip2ForConditionalGeneration(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Blip2Model(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Blip2PreTrainedModel(metaclass=DummyObject):
|
class Blip2PreTrainedModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import Blip2ForConditionalGeneration, Blip2VisionModel
|
from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel
|
||||||
from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
@@ -664,8 +664,8 @@ class Blip2ForConditionalGenerationModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Blip2ForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase):
|
class Blip2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
@@ -737,6 +737,56 @@ class Blip2ForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model = Blip2ForConditionalGeneration.from_pretrained(model_name)
|
model = Blip2ForConditionalGeneration.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_get_text_features(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(torch_device),
|
||||||
|
"attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).to(torch_device),
|
||||||
|
"decoder_input_ids": torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(torch_device),
|
||||||
|
}
|
||||||
|
|
||||||
|
model = Blip2Model(config).to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
text_features = model.get_text_features(**inputs_dict)
|
||||||
|
self.assertEqual(text_features[0].shape, (1, 10, config.text_config.vocab_size))
|
||||||
|
|
||||||
|
def test_get_image_features(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]
|
||||||
|
|
||||||
|
for key in keys_to_pop:
|
||||||
|
inputs_dict.pop(key)
|
||||||
|
|
||||||
|
model = Blip2Model(config).to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
image_features = model.get_image_features(**inputs_dict)
|
||||||
|
self.assertEqual(
|
||||||
|
image_features[0].shape,
|
||||||
|
(
|
||||||
|
self.model_tester.vision_model_tester.batch_size,
|
||||||
|
self.model_tester.vision_model_tester.seq_length,
|
||||||
|
config.vision_config.hidden_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_qformer_features(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]
|
||||||
|
|
||||||
|
for key in keys_to_pop:
|
||||||
|
inputs_dict.pop(key)
|
||||||
|
|
||||||
|
model = Blip2Model(config).to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
qformer_features = model.get_qformer_features(**inputs_dict)
|
||||||
|
self.assertEqual(
|
||||||
|
qformer_features[0].shape,
|
||||||
|
(self.model_tester.vision_model_tester.batch_size, 10, config.vision_config.hidden_size),
|
||||||
|
)
|
||||||
|
|
||||||
# override from common to deal with nested configurations (`vision_config`, `text_config` and `qformer_config`)
|
# override from common to deal with nested configurations (`vision_config`, `text_config` and `qformer_config`)
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user