IDEFICS: allow interpolation of vision's pos embeddings (#26029)
* add pos embed interpolation for vision encoder * style * update config with interpolate_pos_encoding arg * fix imports formatting * take off copied from on vision embeddings * add test for image embeddings interpolation * add credit for interpolation code * Update src/transformers/models/idefics/configuration_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/idefics/vision.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix condition to check nbr image patches match shape of pos embeddings * use kwargs in the forward methods for interpolation * fix tests * have interpolate_pos_encoding default to False instead of None * Update tests/models/idefics/test_modeling_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/idefics/test_modeling_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/idefics/test_modeling_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/idefics/configuration_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * take off for loop meant to print k,v * add interpolate_pos_encoding arg in prepare_inputs_for_generation * add test for interpolated generation * fix edge case num_patches == num_positions and height == width * add test for edge case * fix pos_embed in interpolate * allow interpolation in bf16 with upcasting * Update src/transformers/models/idefics/vision.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/idefics/vision.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add multiple images tests for interpolation and generation --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -236,6 +236,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
|
|||||||
image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
|
image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
|
||||||
perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
|
perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
|
||||||
image_attention_mask = kwargs.get("image_attention_mask", None)
|
image_attention_mask = kwargs.get("image_attention_mask", None)
|
||||||
|
interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
@@ -248,6 +249,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
|
|||||||
"image_encoder_embeddings": image_encoder_embeddings,
|
"image_encoder_embeddings": image_encoder_embeddings,
|
||||||
"perceiver_embeddings": perceiver_embeddings,
|
"perceiver_embeddings": perceiver_embeddings,
|
||||||
"image_attention_mask": image_attention_mask,
|
"image_attention_mask": image_attention_mask,
|
||||||
|
"interpolate_pos_encoding": interpolate_pos_encoding,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -1157,6 +1159,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: Optional[bool] = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
|
) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
@@ -1212,7 +1215,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
|
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
|
||||||
|
|
||||||
# Get sequence from the vision encoder
|
# Get sequence from the vision encoder
|
||||||
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
|
image_hidden_states = self.vision_model(
|
||||||
|
pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
).last_hidden_state
|
||||||
|
|
||||||
elif image_encoder_embeddings is not None:
|
elif image_encoder_embeddings is not None:
|
||||||
batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size()
|
batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size()
|
||||||
@@ -1468,6 +1473,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
|||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: Optional[bool] = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
|
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1516,6 +1522,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
|
""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -24,10 +25,7 @@ from torch import nn
|
|||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||||
from ...utils import (
|
from ...utils import ModelOutput, logging
|
||||||
ModelOutput,
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
from .configuration_idefics import IdeficsVisionConfig
|
from .configuration_idefics import IdeficsVisionConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -63,7 +61,7 @@ class IdeficsVisionModelOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics
|
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings
|
||||||
class IdeficsVisionEmbeddings(nn.Module):
|
class IdeficsVisionEmbeddings(nn.Module):
|
||||||
def __init__(self, config: IdeficsVisionConfig):
|
def __init__(self, config: IdeficsVisionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -87,15 +85,79 @@ class IdeficsVisionEmbeddings(nn.Module):
|
|||||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
# Heavily inspired from https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/vit/modeling_vit.py#L82
|
||||||
batch_size = pixel_values.shape[0]
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||||
|
resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_patches = embeddings.shape[1] - 1
|
||||||
|
pos_embed = self.position_embedding(self.position_ids)
|
||||||
|
num_positions = pos_embed.shape[1] - 1
|
||||||
|
if num_patches == num_positions and height == width:
|
||||||
|
return pos_embed
|
||||||
|
class_pos_embed = pos_embed[:, 0]
|
||||||
|
patch_pos_embed = pos_embed[:, 1:]
|
||||||
|
|
||||||
|
embed_dim = embeddings.shape[-1]
|
||||||
|
num_h_patches = height // self.config.patch_size
|
||||||
|
num_w_patches = width // self.config.patch_size
|
||||||
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
|
||||||
|
sqrt_num_positions = math.sqrt(num_positions)
|
||||||
|
patch_pos_embed = patch_pos_embed.reshape(1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
fp32_upcasting = patch_pos_embed.dtype == torch.bfloat16
|
||||||
|
if fp32_upcasting:
|
||||||
|
logger.warning_once(
|
||||||
|
"Upcasting patch_pos_embed to fp32 for interpolation since `upsample_bicubic2d_out_frame` in nn.functional.interpolate"
|
||||||
|
"is not implemented for 'torch.bfloat16' dtype. This will result in a slight overhead"
|
||||||
|
)
|
||||||
|
patch_pos_embed = patch_pos_embed.to(torch.float)
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
scale_factor=(num_h_patches / sqrt_num_positions, num_w_patches / sqrt_num_positions),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
if fp32_upcasting:
|
||||||
|
patch_pos_embed = patch_pos_embed.to(torch.bfloat16)
|
||||||
|
if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
|
||||||
|
f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})"
|
||||||
|
)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)
|
||||||
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||||
|
|
||||||
|
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
||||||
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
|
if not interpolate_pos_encoding:
|
||||||
|
if height != self.image_size or width != self.image_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
|
f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`"
|
||||||
|
)
|
||||||
|
|
||||||
target_dtype = self.patch_embedding.weight.dtype
|
target_dtype = self.patch_embedding.weight.dtype
|
||||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||||
|
|
||||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
|
||||||
|
# add positional encoding to each token
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
|
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
@@ -387,12 +449,13 @@ class IdeficsVisionTransformer(nn.Module):
|
|||||||
self.encoder = IdeficsVisionEncoder(config)
|
self.encoder = IdeficsVisionEncoder(config)
|
||||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
# copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
|
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: Optional[bool] = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||||
r"""
|
r"""
|
||||||
@@ -408,7 +471,7 @@ class IdeficsVisionTransformer(nn.Module):
|
|||||||
if pixel_values is None:
|
if pixel_values is None:
|
||||||
raise ValueError("You have to specify pixel_values")
|
raise ValueError("You have to specify pixel_values")
|
||||||
|
|
||||||
hidden_states = self.embeddings(pixel_values)
|
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
hidden_states = self.pre_layrnorm(hidden_states)
|
hidden_states = self.pre_layrnorm(hidden_states)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
|
|||||||
@@ -74,8 +74,6 @@ class IdeficsModelTester:
|
|||||||
num_labels=3,
|
num_labels=3,
|
||||||
scope=None,
|
scope=None,
|
||||||
modality_type_vocab_size=2,
|
modality_type_vocab_size=2,
|
||||||
add_multiple_images=False,
|
|
||||||
num_images=-1,
|
|
||||||
vision_embed_dim=32,
|
vision_embed_dim=32,
|
||||||
vision_patch_size=2,
|
vision_patch_size=2,
|
||||||
vision_image_size=30,
|
vision_image_size=30,
|
||||||
@@ -113,8 +111,6 @@ class IdeficsModelTester:
|
|||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
self.modality_type_vocab_size = modality_type_vocab_size
|
self.modality_type_vocab_size = modality_type_vocab_size
|
||||||
self.add_multiple_images = add_multiple_images
|
|
||||||
self.num_images = num_images
|
|
||||||
|
|
||||||
self.vision_embed_dim = vision_embed_dim
|
self.vision_embed_dim = vision_embed_dim
|
||||||
self.vision_patch_size = vision_patch_size
|
self.vision_patch_size = vision_patch_size
|
||||||
@@ -150,14 +146,17 @@ class IdeficsModelTester:
|
|||||||
# this is equal to the seq length of the text tokens + number of image patches + 1 for the CLS token
|
# this is equal to the seq length of the text tokens + number of image patches + 1 for the CLS token
|
||||||
self.expected_seq_len = self.seq_length + (self.image_size // self.patch_size) ** 2 + 1
|
self.expected_seq_len = self.seq_length + (self.image_size // self.patch_size) ** 2 + 1
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self, num_images=1, interpolate_pos_encoding=False, image_expansion=0):
|
||||||
self.seq_length = 42
|
|
||||||
|
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
num_images = 2 if self.add_multiple_images else 1
|
|
||||||
pixel_values = floats_tensor(
|
pixel_values = floats_tensor(
|
||||||
[self.batch_size, num_images, self.num_channels, self.image_size, self.image_size]
|
[
|
||||||
|
self.batch_size,
|
||||||
|
num_images,
|
||||||
|
self.num_channels,
|
||||||
|
self.image_size + image_expansion,
|
||||||
|
self.image_size + image_expansion,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
@@ -166,8 +165,7 @@ class IdeficsModelTester:
|
|||||||
image_attention_mask = random_attention_mask([self.batch_size, self.seq_length, num_images])
|
image_attention_mask = random_attention_mask([self.batch_size, self.seq_length, num_images])
|
||||||
|
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
|
return (config, input_ids, input_mask, pixel_values, image_attention_mask, interpolate_pos_encoding)
|
||||||
return (config, input_ids, input_mask, pixel_values, image_attention_mask)
|
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return IdeficsConfig(
|
return IdeficsConfig(
|
||||||
@@ -188,7 +186,6 @@ class IdeficsModelTester:
|
|||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
num_labels=self.num_labels,
|
num_labels=self.num_labels,
|
||||||
modality_type_vocab_size=self.modality_type_vocab_size,
|
modality_type_vocab_size=self.modality_type_vocab_size,
|
||||||
num_images=self.num_images,
|
|
||||||
vision_config=self.vision_config,
|
vision_config=self.vision_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -199,17 +196,43 @@ class IdeficsModelTester:
|
|||||||
input_mask,
|
input_mask,
|
||||||
pixel_values,
|
pixel_values,
|
||||||
image_attention_mask,
|
image_attention_mask,
|
||||||
|
interpolate_pos_encoding,
|
||||||
):
|
):
|
||||||
model = IdeficsModel(config=config)
|
model = IdeficsModel(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
result = model(
|
result = model(
|
||||||
input_ids, attention_mask=input_mask, pixel_values=pixel_values, image_attention_mask=image_attention_mask
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_attention_mask=image_attention_mask,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.last_hidden_state.shape, (self.batch_size, input_ids.shape[1], self.hidden_size)
|
result.last_hidden_state.shape, (self.batch_size, input_ids.shape[1], self.hidden_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_and_check_model_gen(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
pixel_values,
|
||||||
|
image_attention_mask,
|
||||||
|
interpolate_pos_encoding,
|
||||||
|
):
|
||||||
|
model = IdeficsForVisionText2Text(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
model.generate(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_attention_mask=image_attention_mask,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
|
max_length=self.seq_length + 2,
|
||||||
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -218,12 +241,14 @@ class IdeficsModelTester:
|
|||||||
input_mask,
|
input_mask,
|
||||||
pixel_values,
|
pixel_values,
|
||||||
image_attention_mask,
|
image_attention_mask,
|
||||||
|
interpolate_pos_encoding,
|
||||||
) = config_and_inputs
|
) = config_and_inputs
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": input_mask,
|
"attention_mask": input_mask,
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"image_attention_mask": image_attention_mask,
|
"image_attention_mask": image_attention_mask,
|
||||||
|
"interpolate_pos_encoding": interpolate_pos_encoding,
|
||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -268,10 +293,50 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
def test_model(self):
|
def test_model_single_image(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||||
|
num_images=1, interpolate_pos_encoding=False, image_expansion=0
|
||||||
|
)
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_model_multiple_images(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||||
|
num_images=2, interpolate_pos_encoding=False, image_expansion=0
|
||||||
|
)
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_model_with_image_pos_embeddings_interpolation_single_image(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||||
|
num_images=1, interpolate_pos_encoding=True, image_expansion=2
|
||||||
|
)
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||||
|
num_images=1, interpolate_pos_encoding=True, image_expansion=0
|
||||||
|
)
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_model_with_image_pos_embeddings_interpolation_multiple_images(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||||
|
num_images=2, interpolate_pos_encoding=True, image_expansion=2
|
||||||
|
)
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||||
|
num_images=2, interpolate_pos_encoding=True, image_expansion=0
|
||||||
|
)
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_generate_with_image_pos_embeddings_interpolation_single_image(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||||
|
num_images=1, interpolate_pos_encoding=True, image_expansion=2
|
||||||
|
)
|
||||||
|
self.model_tester.create_and_check_model_gen(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_generate_with_image_pos_embeddings_interpolation_multiple_images(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||||
|
num_images=2, interpolate_pos_encoding=True, image_expansion=2
|
||||||
|
)
|
||||||
|
self.model_tester.create_and_check_model_gen(*config_and_inputs)
|
||||||
|
|
||||||
def test_training(self):
|
def test_training(self):
|
||||||
if not self.model_tester.is_training:
|
if not self.model_tester.is_training:
|
||||||
return
|
return
|
||||||
@@ -289,8 +354,6 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
for k, v in inputs.items():
|
|
||||||
print(k, v.shape)
|
|
||||||
loss = model(**inputs).loss
|
loss = model(**inputs).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
@@ -416,7 +479,8 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = IdeficsModelTester(
|
self.model_tester = IdeficsModelTester(
|
||||||
self, modality_type_vocab_size=3, add_multiple_images=True, num_images=2
|
self,
|
||||||
|
modality_type_vocab_size=3,
|
||||||
)
|
)
|
||||||
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
|
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user