[WIP] Add BridgeTowerForContrastiveLearning (#21964)
* Add BridgeTower for ITC * Fix review feedback * Rename BridgeTowerForITC, cleanup * Fix style and quality * implement tests --------- Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com> Co-authored-by: Tiep Le <tiep.le@intel.com>
This commit is contained in:
committed by
GitHub
parent
edea08a6b0
commit
de81adf978
@@ -42,6 +42,28 @@ In principle, one can apply any visual, textual or cross-modal encoder in the pr
|
|||||||
The [`BridgeTowerProcessor`] wraps [`RobertaTokenizer`] and [`BridgeTowerImageProcessor`] into a single instance to both
|
The [`BridgeTowerProcessor`] wraps [`RobertaTokenizer`] and [`BridgeTowerImageProcessor`] into a single instance to both
|
||||||
encode the text and prepare the images respectively.
|
encode the text and prepare the images respectively.
|
||||||
|
|
||||||
|
The following example shows how to run contrastive learning using [`BridgeTowerProcessor`] and [`BridgeTowerForContrastiveLearning`].
|
||||||
|
```python
|
||||||
|
>>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
|
||||||
|
>>> import requests
|
||||||
|
>>> from PIL import Image
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
>>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
|
||||||
|
|
||||||
|
>>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
|
||||||
|
>>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
|
||||||
|
|
||||||
|
>>> # forward pass
|
||||||
|
>>> scores = dict()
|
||||||
|
>>> for text in texts:
|
||||||
|
... # prepare inputs
|
||||||
|
... encoding = processor(image, text, return_tensors="pt")
|
||||||
|
... outputs = model(**encoding)
|
||||||
|
... scores[text] = outputs
|
||||||
|
```
|
||||||
|
|
||||||
The following example shows how to run image-text retrieval using [`BridgeTowerProcessor`] and [`BridgeTowerForImageAndTextRetrieval`].
|
The following example shows how to run image-text retrieval using [`BridgeTowerProcessor`] and [`BridgeTowerForImageAndTextRetrieval`].
|
||||||
```python
|
```python
|
||||||
>>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
|
>>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
|
||||||
@@ -128,6 +150,11 @@ Tips:
|
|||||||
[[autodoc]] BridgeTowerModel
|
[[autodoc]] BridgeTowerModel
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## BridgeTowerForContrastiveLearning
|
||||||
|
|
||||||
|
[[autodoc]] BridgeTowerForContrastiveLearning
|
||||||
|
- forward
|
||||||
|
|
||||||
## BridgeTowerForMaskedLM
|
## BridgeTowerForMaskedLM
|
||||||
|
|
||||||
[[autodoc]] BridgeTowerForMaskedLM
|
[[autodoc]] BridgeTowerForMaskedLM
|
||||||
|
|||||||
@@ -1182,6 +1182,7 @@ else:
|
|||||||
_import_structure["models.bridgetower"].extend(
|
_import_structure["models.bridgetower"].extend(
|
||||||
[
|
[
|
||||||
"BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"BridgeTowerForContrastiveLearning",
|
||||||
"BridgeTowerForImageAndTextRetrieval",
|
"BridgeTowerForImageAndTextRetrieval",
|
||||||
"BridgeTowerForMaskedLM",
|
"BridgeTowerForMaskedLM",
|
||||||
"BridgeTowerModel",
|
"BridgeTowerModel",
|
||||||
@@ -4666,6 +4667,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.bridgetower import (
|
from .models.bridgetower import (
|
||||||
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
BridgeTowerForContrastiveLearning,
|
||||||
BridgeTowerForImageAndTextRetrieval,
|
BridgeTowerForImageAndTextRetrieval,
|
||||||
BridgeTowerForMaskedLM,
|
BridgeTowerForMaskedLM,
|
||||||
BridgeTowerModel,
|
BridgeTowerModel,
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modeling_bridgetower"] = [
|
_import_structure["modeling_bridgetower"] = [
|
||||||
"BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"BridgeTowerForContrastiveLearning",
|
||||||
"BridgeTowerForImageAndTextRetrieval",
|
"BridgeTowerForImageAndTextRetrieval",
|
||||||
"BridgeTowerForMaskedLM",
|
"BridgeTowerForMaskedLM",
|
||||||
"BridgeTowerModel",
|
"BridgeTowerModel",
|
||||||
@@ -74,6 +75,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_bridgetower import (
|
from .modeling_bridgetower import (
|
||||||
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
BridgeTowerForContrastiveLearning,
|
||||||
BridgeTowerForImageAndTextRetrieval,
|
BridgeTowerForImageAndTextRetrieval,
|
||||||
BridgeTowerForMaskedLM,
|
BridgeTowerForMaskedLM,
|
||||||
BridgeTowerModel,
|
BridgeTowerModel,
|
||||||
|
|||||||
@@ -143,9 +143,8 @@ class BridgeTowerModelOutput(ModelOutput):
|
|||||||
token), respectively, after further processing through layers used for auxiliary pretraining tasks.
|
token), respectively, after further processing through layers used for auxiliary pretraining tasks.
|
||||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
|
||||||
|
the model at the output of each layer plus the optional initial embedding outputs.
|
||||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
|
||||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||||
sequence_length)`.
|
sequence_length)`.
|
||||||
@@ -161,6 +160,40 @@ class BridgeTowerModelOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BridgeTowerContrastiveOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Output type of [`BridgeTowerForContrastiveLearning`]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (`torch.FloatTensor` of shape `(batch_size, 3, config.contrastive_hidden_size)`):
|
||||||
|
The L2-normalized text, image and cross-modal embeddings, stacked along the second dimension in that order.
|
||||||
|
text_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
|
||||||
|
The text embeddings obtained by applying the projection layer to the pooler_output.
|
||||||
|
image_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
|
||||||
|
The image embeddings obtained by applying the projection layer to the pooler_output.
|
||||||
|
cross_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
|
||||||
|
The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
|
||||||
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||||
|
Image-text contrastive loss.
|
||||||
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||||
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
|
||||||
|
the model at the output of each layer plus the optional initial embedding outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits: torch.FloatTensor = None
|
||||||
|
text_embeds: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
image_embeds: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
cross_embeds: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
loss: Optional[torch.FloatTensor] = None
|
||||||
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
class BridgeTowerResidualAttention(nn.Module):
|
class BridgeTowerResidualAttention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1314,7 +1347,12 @@ class BridgeTowerModel(BridgeTowerPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states_text += (text_embeds,)
|
all_hidden_states_text += (text_embeds,)
|
||||||
|
|
||||||
|
if image_embeds is None:
|
||||||
image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype))
|
image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype))
|
||||||
|
else:
|
||||||
|
# Permute as BridgeTowerResidualAttention has batch_first=True
|
||||||
|
image_embeds = image_embeds.permute(1, 0, 2)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states_image += (image_embeds,)
|
all_hidden_states_image += (image_embeds,)
|
||||||
|
|
||||||
@@ -1438,7 +1476,11 @@ class BridgeTowerModel(BridgeTowerPreTrainedModel):
|
|||||||
all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross)
|
all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [text_features, image_features, cls_features] if v is not None)
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [text_features, image_features, cls_features, all_hidden_states, all_self_attentions]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return BridgeTowerModelOutput(
|
return BridgeTowerModelOutput(
|
||||||
text_features=text_features,
|
text_features=text_features,
|
||||||
@@ -1700,3 +1742,138 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BridgeTowerContrastiveHead(nn.Module):
|
||||||
|
def __init__(self, hidden_size, embed_size):
|
||||||
|
super().__init__()
|
||||||
|
self.fc = nn.Linear(hidden_size, embed_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
BridgeTower Model with an image-text contrastive head on top computing image-text contrastive loss.
|
||||||
|
""",
|
||||||
|
BRIDGETOWER_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.bridgetower = BridgeTowerModel(config)
|
||||||
|
|
||||||
|
self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
|
||||||
|
self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
|
||||||
|
self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size)
|
||||||
|
|
||||||
|
self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=BridgeTowerContrastiveOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
pixel_mask: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
image_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = True,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
|
||||||
|
Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
|
||||||
|
The pairs with 0 will be skipped for calculation.
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
|
||||||
|
>>> import requests
|
||||||
|
>>> from PIL import Image
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
>>> texts = "An image of two cats chilling on a couch"
|
||||||
|
|
||||||
|
>>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
|
||||||
|
>>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
|
||||||
|
>>> inputs = processor(image, texts, return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs, output_hidden_states=True)
|
||||||
|
```"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
outputs = self.bridgetower(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
pixel_mask=pixel_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
image_embeds=image_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
pooler_output = outputs.pooler_output if return_dict else outputs[2]
|
||||||
|
hidden_states_txt, hidden_states_img, hidden_states_cross_modal = (
|
||||||
|
outputs.hidden_states if return_dict else outputs[3]
|
||||||
|
)
|
||||||
|
|
||||||
|
text_embeds = hidden_states_txt[-1]
|
||||||
|
image_embeds = hidden_states_img[-1]
|
||||||
|
|
||||||
|
image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds)
|
||||||
|
image_token_type_embeddings = self.bridgetower.token_type_embeddings(
|
||||||
|
torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
|
||||||
|
).expand_as(image_embeds_with_ln)
|
||||||
|
|
||||||
|
image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings
|
||||||
|
|
||||||
|
# normalized features
|
||||||
|
text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
|
||||||
|
image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2)
|
||||||
|
cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)
|
||||||
|
|
||||||
|
logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
|
||||||
|
|
||||||
|
logit_scale = self.logit_scale.exp()
|
||||||
|
logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||||
|
logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
|
||||||
|
logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
|
||||||
|
|
||||||
|
itc_loss = None
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
labels = torch.arange(len(labels), device=logits.device)
|
||||||
|
text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
|
||||||
|
text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
|
||||||
|
image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
|
||||||
|
itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = tuple(logits)
|
||||||
|
return ((itc_loss,) + output) if itc_loss is not None else output
|
||||||
|
|
||||||
|
return BridgeTowerContrastiveOutput(
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
text_embeds=text_embeds,
|
||||||
|
image_embeds=image_embeds,
|
||||||
|
cross_embeds=cross_embeds,
|
||||||
|
logits=logits,
|
||||||
|
loss=itc_loss,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1328,6 +1328,13 @@ class BloomPreTrainedModel(metaclass=DummyObject):
|
|||||||
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class BridgeTowerForContrastiveLearning(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class BridgeTowerForImageAndTextRetrieval(metaclass=DummyObject):
|
class BridgeTowerForImageAndTextRetrieval(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -24,14 +24,25 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
|
|||||||
from transformers.utils import cached_property
|
from transformers.utils import cached_property
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
from ...test_modeling_common import (
|
||||||
|
ModelTesterMixin,
|
||||||
|
_config_zero_init,
|
||||||
|
floats_tensor,
|
||||||
|
ids_tensor,
|
||||||
|
random_attention_mask,
|
||||||
|
)
|
||||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM, BridgeTowerModel
|
from transformers import (
|
||||||
|
BridgeTowerForContrastiveLearning,
|
||||||
|
BridgeTowerForImageAndTextRetrieval,
|
||||||
|
BridgeTowerForMaskedLM,
|
||||||
|
BridgeTowerModel,
|
||||||
|
)
|
||||||
from transformers.models.bridgetower.modeling_bridgetower import BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.bridgetower.modeling_bridgetower import BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_10
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_10
|
||||||
else:
|
else:
|
||||||
@@ -65,6 +76,8 @@ class BridgeTowerModelTester:
|
|||||||
text_config=None,
|
text_config=None,
|
||||||
vision_config=None,
|
vision_config=None,
|
||||||
image_size=288,
|
image_size=288,
|
||||||
|
contrastive_hidden_size=512,
|
||||||
|
logit_scale_init_value=2.6592,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
|
self.share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
|
||||||
@@ -90,6 +103,8 @@ class BridgeTowerModelTester:
|
|||||||
self.is_training = False
|
self.is_training = False
|
||||||
self.expected_num_hidden_layers = 32
|
self.expected_num_hidden_layers = 32
|
||||||
self.output_hidden_states = output_hidden_states
|
self.output_hidden_states = output_hidden_states
|
||||||
|
self.contrastive_hidden_size = contrastive_hidden_size
|
||||||
|
self.logit_scale_init_value = logit_scale_init_value
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -118,6 +133,8 @@ class BridgeTowerModelTester:
|
|||||||
init_layernorm_from_vision_encoder=self.init_layernorm_from_vision_encoder,
|
init_layernorm_from_vision_encoder=self.init_layernorm_from_vision_encoder,
|
||||||
num_channels=self.num_channels,
|
num_channels=self.num_channels,
|
||||||
output_hidden_states=self.output_hidden_states,
|
output_hidden_states=self.output_hidden_states,
|
||||||
|
contrastive_hidden_size=self.contrastive_hidden_size,
|
||||||
|
logit_scale_init_value=self.logit_scale_init_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(
|
def create_and_check_model(
|
||||||
@@ -189,7 +206,14 @@ class BridgeTowerModelTester:
|
|||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
||||||
class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(BridgeTowerModel, BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
|
(
|
||||||
|
BridgeTowerModel,
|
||||||
|
BridgeTowerForImageAndTextRetrieval,
|
||||||
|
BridgeTowerForMaskedLM,
|
||||||
|
BridgeTowerForContrastiveLearning,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = {"feature-extraction": BridgeTowerModel} if is_torch_available() else {}
|
pipeline_model_mapping = {"feature-extraction": BridgeTowerModel} if is_torch_available() else {}
|
||||||
|
|
||||||
@@ -347,6 +371,29 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
if self.has_attentions:
|
if self.has_attentions:
|
||||||
self.assertIsNotNone(attentions.grad)
|
self.assertIsNotNone(attentions.grad)
|
||||||
|
|
||||||
|
# override as the `logit_scale` parameter initialization is different for BRIDGE TOWER
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if name == "logit_scale":
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
param.data.item(),
|
||||||
|
config.logit_scale_init_value,
|
||||||
|
delta=1e-3,
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
@unittest.skip(reason="""Bridge Tower does not have input/output embeddings. So this test is not applicable.""")
|
@unittest.skip(reason="""Bridge Tower does not have input/output embeddings. So this test is not applicable.""")
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
@@ -429,12 +476,31 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4)
|
self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_constrastive_learning(self):
|
||||||
|
model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc").to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
model.eval()
|
||||||
|
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
|
||||||
|
image = prepare_img()
|
||||||
|
text = "a bunch of cats laying on a tower."
|
||||||
|
inputs = processor(image, text, return_tensors="pt").to(torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs, output_hidden_states=True)
|
||||||
|
|
||||||
|
# verify the logits
|
||||||
|
expected_shape = torch.Size([1, 3, 512])
|
||||||
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
||||||
class BridgeTowerModelTrainingTest(unittest.TestCase):
|
class BridgeTowerModelTrainingTest(unittest.TestCase):
|
||||||
all_training_supported_model_classes = (
|
all_training_supported_model_classes = (
|
||||||
(BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
|
(BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM, BridgeTowerForContrastiveLearning)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -445,7 +511,7 @@ class BridgeTowerModelTrainingTest(unittest.TestCase):
|
|||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if model_class == BridgeTowerForMaskedLM:
|
if model_class == BridgeTowerForMaskedLM:
|
||||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||||
elif model_class == BridgeTowerForImageAndTextRetrieval:
|
elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
|
||||||
inputs_dict["labels"] = ids_tensor([1], 2)
|
inputs_dict["labels"] = ids_tensor([1], 2)
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|||||||
@@ -204,6 +204,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
|||||||
"Swin2SRForImageSuperResolution",
|
"Swin2SRForImageSuperResolution",
|
||||||
"BridgeTowerForImageAndTextRetrieval",
|
"BridgeTowerForImageAndTextRetrieval",
|
||||||
"BridgeTowerForMaskedLM",
|
"BridgeTowerForMaskedLM",
|
||||||
|
"BridgeTowerForContrastiveLearning",
|
||||||
"CLIPSegForImageSegmentation",
|
"CLIPSegForImageSegmentation",
|
||||||
"CLIPSegVisionModel",
|
"CLIPSegVisionModel",
|
||||||
"CLIPSegTextModel",
|
"CLIPSegTextModel",
|
||||||
|
|||||||
Reference in New Issue
Block a user