[Blip] Remove redundant shift right (#23153)
* remove redundant shift right * fix failing tests * this time fix tests
This commit is contained in:
@@ -1121,19 +1121,6 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
|||||||
def get_input_embeddings(self) -> nn.Module:
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
return self.vision_model.embeddings.patch_embedding
|
return self.vision_model.embeddings.patch_embedding
|
||||||
|
|
||||||
# Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
|
|
||||||
def _shift_right(self, input_ids):
|
|
||||||
pad_token_id = self.decoder_pad_token_id
|
|
||||||
|
|
||||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
|
||||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
|
||||||
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
|
||||||
|
|
||||||
# replace possible -100 values in labels by `pad_token_id`
|
|
||||||
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
|
||||||
|
|
||||||
return shifted_input_ids
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1215,10 +1202,6 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
|||||||
|
|
||||||
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
||||||
|
|
||||||
if labels is not None and decoder_input_ids is None:
|
|
||||||
# get decoder inputs from shifting lm labels to the right - this is used in training mode
|
|
||||||
decoder_input_ids = self._shift_right(labels)
|
|
||||||
|
|
||||||
answer_output = self.text_decoder(
|
answer_output = self.text_decoder(
|
||||||
input_ids=decoder_input_ids,
|
input_ids=decoder_input_ids,
|
||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
|
|||||||
@@ -1335,30 +1335,6 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
|
|||||||
attentions=attns,
|
attentions=attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Adapted from transformers.models.t5.modeling_tf_t5.TFT5PreTrainedModel._shift_right
|
|
||||||
def _shift_right(self, input_ids):
|
|
||||||
decoder_start_token_id = self.decoder_start_token_id
|
|
||||||
pad_token_id = self.decoder_pad_token_id
|
|
||||||
|
|
||||||
if decoder_start_token_id is None or pad_token_id is None:
|
|
||||||
raise ValueError("decoder_start_token_id and pad_token_id must be defined!")
|
|
||||||
|
|
||||||
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
|
||||||
start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation
|
|
||||||
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
|
||||||
|
|
||||||
# replace possible -100 values in labels by `pad_token_id`
|
|
||||||
shifted_input_ids = tf.where(
|
|
||||||
shifted_input_ids == -100,
|
|
||||||
tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
|
|
||||||
shifted_input_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
|
||||||
tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype))
|
|
||||||
|
|
||||||
return shifted_input_ids
|
|
||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TFBlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
@replace_return_docstrings(output_type=TFBlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
||||||
@@ -1440,10 +1416,6 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
|
|||||||
|
|
||||||
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
||||||
|
|
||||||
if labels is not None and decoder_input_ids is None:
|
|
||||||
# get decoder inputs from shifting lm labels to the right - this is used in training mode
|
|
||||||
decoder_input_ids = self._shift_right(labels)
|
|
||||||
|
|
||||||
answer_output = self.text_decoder(
|
answer_output = self.text_decoder(
|
||||||
input_ids=decoder_input_ids,
|
input_ids=decoder_input_ids,
|
||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
|
|||||||
@@ -626,17 +626,73 @@ class BlipTextImageModelsModelTester:
|
|||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
class BlipVQAModelTester:
|
||||||
|
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||||
|
if text_kwargs is None:
|
||||||
|
text_kwargs = {}
|
||||||
|
if vision_kwargs is None:
|
||||||
|
vision_kwargs = {}
|
||||||
|
|
||||||
|
self.parent = parent
|
||||||
|
self.text_model_tester = BlipTextModelTester(parent, **text_kwargs)
|
||||||
|
self.vision_model_tester = BlipVisionModelTester(parent, **vision_kwargs)
|
||||||
|
self.is_training = is_training
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||||
|
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, input_ids, attention_mask, pixel_values
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return BlipConfig.from_text_vision_configs(
|
||||||
|
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||||
|
model = BlipModel(config).to(torch_device).eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
result = model(input_ids, pixel_values, attention_mask)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"labels": input_ids,
|
||||||
|
"decoder_input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_vision
|
@require_vision
|
||||||
class BlipVQAModelTest(unittest.TestCase):
|
class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (BlipForQuestionAnswering,) if is_torch_available() else ()
|
all_model_classes = (BlipForQuestionAnswering,) if is_torch_available() else ()
|
||||||
|
fx_compatible = False
|
||||||
|
test_head_masking = False
|
||||||
|
test_pruning = False
|
||||||
|
test_resize_embeddings = False
|
||||||
|
test_attention_outputs = False
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = BlipModelTester(self)
|
self.model_tester = BlipVQAModelTester(self)
|
||||||
|
|
||||||
def _prepare_inputs_for_vqa(self):
|
def _prepare_inputs_for_vqa(self):
|
||||||
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||||
|
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
|
||||||
inputs_dict.pop("return_loss")
|
inputs_dict.pop("return_loss")
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
@@ -658,7 +714,7 @@ class BlipVQAModelTest(unittest.TestCase):
|
|||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(self.model_tester.get_config()).to(torch_device)
|
model = model_class(self.model_tester.get_config()).to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
loss = model(**self._prepare_inputs_for_vqa()).loss
|
loss = model(**self.model_tester.prepare_config_and_inputs_for_common()[1]).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# verify the gradients are not None
|
# verify the gradients are not None
|
||||||
@@ -687,6 +743,18 @@ class BlipVQAModelTest(unittest.TestCase):
|
|||||||
f"Argument {arg} of forward function signature should include {arg}. Found {args}.",
|
f"Argument {arg} of forward function signature should include {arg}. Found {args}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="BlipModel does not have input/output embeddings")
|
||||||
|
def test_model_common_attributes(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
@@ -886,14 +954,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (BlipForConditionalGeneration,) if is_torch_available() else ()
|
||||||
(
|
|
||||||
BlipForConditionalGeneration,
|
|
||||||
BlipForQuestionAnswering,
|
|
||||||
)
|
|
||||||
if is_torch_available()
|
|
||||||
else ()
|
|
||||||
)
|
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -526,17 +526,71 @@ class BlipTextImageModelsModelTester:
|
|||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
class BlipVQAModelsModelTester:
|
||||||
|
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||||
|
if text_kwargs is None:
|
||||||
|
text_kwargs = {}
|
||||||
|
if vision_kwargs is None:
|
||||||
|
vision_kwargs = {}
|
||||||
|
|
||||||
|
self.parent = parent
|
||||||
|
self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
|
||||||
|
self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
|
||||||
|
self.is_training = is_training
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||||
|
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, input_ids, attention_mask, pixel_values
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return BlipConfig.from_text_vision_configs(
|
||||||
|
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||||
|
model = TFBlipModel(config)
|
||||||
|
result = model(input_ids, pixel_values, attention_mask, training=False)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"decoder_input_ids": input_ids,
|
||||||
|
"labels": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@require_vision
|
@require_vision
|
||||||
class BlipVQAModelTest(unittest.TestCase):
|
class TFBlipVQAModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (TFBlipForQuestionAnswering,) if is_tf_available() else ()
|
all_model_classes = (TFBlipForQuestionAnswering,) if is_tf_available() else ()
|
||||||
|
test_head_masking = False
|
||||||
|
test_pruning = False
|
||||||
|
test_resize_embeddings = False
|
||||||
|
test_attention_outputs = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFBlipModelTester(self)
|
self.model_tester = BlipVQAModelsModelTester(self)
|
||||||
|
|
||||||
def _prepare_inputs_for_vqa(self):
|
def _prepare_inputs_for_vqa(self):
|
||||||
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||||
|
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
|
||||||
inputs_dict.pop("return_loss")
|
inputs_dict.pop("return_loss")
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
@@ -557,10 +611,34 @@ class BlipVQAModelTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(self.model_tester.get_config())
|
model = model_class(self.model_tester.get_config())
|
||||||
loss = model(**self._prepare_inputs_for_vqa(), training=True).loss
|
loss = model(**self.model_tester.prepare_config_and_inputs_for_common()[1], training=True).loss
|
||||||
|
|
||||||
self.assertIsNotNone(loss, "Loss should not be None")
|
self.assertIsNotNone(loss, "Loss should not be None")
|
||||||
|
|
||||||
|
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="BlipModel does not have input/output embeddings")
|
||||||
|
def test_model_common_attributes(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Tested in individual model tests")
|
||||||
|
def test_compile_tf_model(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Model doesn't have a clean loss output.")
|
||||||
|
def test_keras_fit(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
@@ -643,7 +721,7 @@ class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (TFBlipForConditionalGeneration, TFBlipForQuestionAnswering) if is_tf_available() else ()
|
all_model_classes = (TFBlipForConditionalGeneration,) if is_tf_available() else ()
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
|
|||||||
Reference in New Issue
Block a user