IDEFICS: allow interpolation of vision's pos embeddings (#26029)
* add pos embed interpolation for vision encoder * style * update config with interpolate_pos_encoding arg * fix imports formatting * take off copied from on vision embeddings * add test for image embeddings interpolation * add credit for interpolation code * Update src/transformers/models/idefics/configuration_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/idefics/vision.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix condition to check nbr image patches match shape of pos embeddings * use kwargs in the forward methods for interpolation * fix tests * have interpolate_pos_encoding default to False instead of None * Update tests/models/idefics/test_modeling_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/idefics/test_modeling_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/idefics/test_modeling_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/idefics/configuration_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * take off for loop meant to print k,v * add interpolate_pos_encoding arg in prepare_inputs_for_generation * add test for interpolated generation * fix edge case num_patches == num_positions and height == width * add test for edge case * fix pos_embed in interpolate * allow interpolation in bf16 with upcasting * Update src/transformers/models/idefics/vision.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/idefics/vision.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add multiple images tests for interpolation and generation --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -74,8 +74,6 @@ class IdeficsModelTester:
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
modality_type_vocab_size=2,
|
||||
add_multiple_images=False,
|
||||
num_images=-1,
|
||||
vision_embed_dim=32,
|
||||
vision_patch_size=2,
|
||||
vision_image_size=30,
|
||||
@@ -113,8 +111,6 @@ class IdeficsModelTester:
|
||||
self.num_labels = num_labels
|
||||
self.scope = scope
|
||||
self.modality_type_vocab_size = modality_type_vocab_size
|
||||
self.add_multiple_images = add_multiple_images
|
||||
self.num_images = num_images
|
||||
|
||||
self.vision_embed_dim = vision_embed_dim
|
||||
self.vision_patch_size = vision_patch_size
|
||||
@@ -150,14 +146,17 @@ class IdeficsModelTester:
|
||||
# this is equal to the seq length of the text tokens + number of image patches + 1 for the CLS token
|
||||
self.expected_seq_len = self.seq_length + (self.image_size // self.patch_size) ** 2 + 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
self.seq_length = 42
|
||||
|
||||
def prepare_config_and_inputs(self, num_images=1, interpolate_pos_encoding=False, image_expansion=0):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
num_images = 2 if self.add_multiple_images else 1
|
||||
pixel_values = floats_tensor(
|
||||
[self.batch_size, num_images, self.num_channels, self.image_size, self.image_size]
|
||||
[
|
||||
self.batch_size,
|
||||
num_images,
|
||||
self.num_channels,
|
||||
self.image_size + image_expansion,
|
||||
self.image_size + image_expansion,
|
||||
]
|
||||
)
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
@@ -166,8 +165,7 @@ class IdeficsModelTester:
|
||||
image_attention_mask = random_attention_mask([self.batch_size, self.seq_length, num_images])
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return (config, input_ids, input_mask, pixel_values, image_attention_mask)
|
||||
return (config, input_ids, input_mask, pixel_values, image_attention_mask, interpolate_pos_encoding)
|
||||
|
||||
def get_config(self):
|
||||
return IdeficsConfig(
|
||||
@@ -188,7 +186,6 @@ class IdeficsModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
num_labels=self.num_labels,
|
||||
modality_type_vocab_size=self.modality_type_vocab_size,
|
||||
num_images=self.num_images,
|
||||
vision_config=self.vision_config,
|
||||
)
|
||||
|
||||
@@ -199,17 +196,43 @@ class IdeficsModelTester:
|
||||
input_mask,
|
||||
pixel_values,
|
||||
image_attention_mask,
|
||||
interpolate_pos_encoding,
|
||||
):
|
||||
model = IdeficsModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids, attention_mask=input_mask, pixel_values=pixel_values, image_attention_mask=image_attention_mask
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
pixel_values=pixel_values,
|
||||
image_attention_mask=image_attention_mask,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, input_ids.shape[1], self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_model_gen(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
pixel_values,
|
||||
image_attention_mask,
|
||||
interpolate_pos_encoding,
|
||||
):
|
||||
model = IdeficsForVisionText2Text(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
model.generate(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
pixel_values=pixel_values,
|
||||
image_attention_mask=image_attention_mask,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
max_length=self.seq_length + 2,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -218,12 +241,14 @@ class IdeficsModelTester:
|
||||
input_mask,
|
||||
pixel_values,
|
||||
image_attention_mask,
|
||||
interpolate_pos_encoding,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"image_attention_mask": image_attention_mask,
|
||||
"interpolate_pos_encoding": interpolate_pos_encoding,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -268,10 +293,50 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
def test_model_single_image(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||
num_images=1, interpolate_pos_encoding=False, image_expansion=0
|
||||
)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_multiple_images(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||
num_images=2, interpolate_pos_encoding=False, image_expansion=0
|
||||
)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_with_image_pos_embeddings_interpolation_single_image(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||
num_images=1, interpolate_pos_encoding=True, image_expansion=2
|
||||
)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||
num_images=1, interpolate_pos_encoding=True, image_expansion=0
|
||||
)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_with_image_pos_embeddings_interpolation_multiple_images(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||
num_images=2, interpolate_pos_encoding=True, image_expansion=2
|
||||
)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||
num_images=2, interpolate_pos_encoding=True, image_expansion=0
|
||||
)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_generate_with_image_pos_embeddings_interpolation_single_image(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||
num_images=1, interpolate_pos_encoding=True, image_expansion=2
|
||||
)
|
||||
self.model_tester.create_and_check_model_gen(*config_and_inputs)
|
||||
|
||||
def test_generate_with_image_pos_embeddings_interpolation_multiple_images(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(
|
||||
num_images=2, interpolate_pos_encoding=True, image_expansion=2
|
||||
)
|
||||
self.model_tester.create_and_check_model_gen(*config_and_inputs)
|
||||
|
||||
def test_training(self):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
@@ -289,8 +354,6 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
for k, v in inputs.items():
|
||||
print(k, v.shape)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
@@ -416,7 +479,8 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = IdeficsModelTester(
|
||||
self, modality_type_vocab_size=3, add_multiple_images=True, num_images=2
|
||||
self,
|
||||
modality_type_vocab_size=3,
|
||||
)
|
||||
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user