[Blip] Remove redundant shift right (#23153)

* remove redundant shit right

* fix failing tests

* this time fix tests
This commit is contained in:
Younes Belkada
2023-05-19 19:14:16 +02:00
committed by GitHub
parent 847e5691a6
commit 3cb9309024
4 changed files with 154 additions and 60 deletions

View File

@@ -626,17 +626,73 @@ class BlipTextImageModelsModelTester:
return config, inputs_dict
class BlipVQAModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
if text_kwargs is None:
text_kwargs = {}
if vision_kwargs is None:
vision_kwargs = {}
self.parent = parent
self.text_model_tester = BlipTextModelTester(parent, **text_kwargs)
self.vision_model_tester = BlipVisionModelTester(parent, **vision_kwargs)
self.is_training = is_training
def prepare_config_and_inputs(self):
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_ids, attention_mask, pixel_values
def get_config(self):
return BlipConfig.from_text_vision_configs(
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
)
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = BlipModel(config).to(torch_device).eval()
with torch.no_grad():
result = model(input_ids, pixel_values, attention_mask)
self.parent.assertEqual(
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
)
self.parent.assertEqual(
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"labels": input_ids,
"decoder_input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
return config, inputs_dict
@require_torch
@require_vision
class BlipVQAModelTest(unittest.TestCase):
class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (BlipForQuestionAnswering,) if is_torch_available() else ()
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
test_torchscript = False
def setUp(self):
self.model_tester = BlipModelTester(self)
self.model_tester = BlipVQAModelTester(self)
def _prepare_inputs_for_vqa(self):
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict["labels"] = inputs_dict["input_ids"]
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
inputs_dict.pop("return_loss")
return inputs_dict
@@ -658,7 +714,7 @@ class BlipVQAModelTest(unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(self.model_tester.get_config()).to(torch_device)
model.train()
loss = model(**self._prepare_inputs_for_vqa()).loss
loss = model(**self.model_tester.prepare_config_and_inputs_for_common()[1]).loss
loss.backward()
# verify the gradients are not None
@@ -687,6 +743,18 @@ class BlipVQAModelTest(unittest.TestCase):
f"Argument {arg} of forward function signature should include {arg}. Found {args}.",
)
@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="BlipModel does not have input/output embeddings")
def test_model_common_attributes(self):
pass
@require_torch
class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
@@ -886,14 +954,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
BlipForConditionalGeneration,
BlipForQuestionAnswering,
)
if is_torch_available()
else ()
)
all_model_classes = (BlipForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False
test_head_masking = False
test_pruning = False

View File

@@ -526,17 +526,71 @@ class BlipTextImageModelsModelTester:
return config, inputs_dict
class BlipVQAModelsModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
if text_kwargs is None:
text_kwargs = {}
if vision_kwargs is None:
vision_kwargs = {}
self.parent = parent
self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
self.is_training = is_training
def prepare_config_and_inputs(self):
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_ids, attention_mask, pixel_values
def get_config(self):
return BlipConfig.from_text_vision_configs(
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
)
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = TFBlipModel(config)
result = model(input_ids, pixel_values, attention_mask, training=False)
self.parent.assertEqual(
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
)
self.parent.assertEqual(
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"decoder_input_ids": input_ids,
"labels": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
return config, inputs_dict
@require_tf
@require_vision
class BlipVQAModelTest(unittest.TestCase):
class TFBlipVQAModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlipForQuestionAnswering,) if is_tf_available() else ()
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
test_onnx = False
def setUp(self):
self.model_tester = TFBlipModelTester(self)
self.model_tester = BlipVQAModelsModelTester(self)
def _prepare_inputs_for_vqa(self):
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict["labels"] = inputs_dict["input_ids"]
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
inputs_dict.pop("return_loss")
return inputs_dict
@@ -557,10 +611,34 @@ class BlipVQAModelTest(unittest.TestCase):
"""
for model_class in self.all_model_classes:
model = model_class(self.model_tester.get_config())
loss = model(**self._prepare_inputs_for_vqa(), training=True).loss
loss = model(**self.model_tester.prepare_config_and_inputs_for_common()[1], training=True).loss
self.assertIsNotNone(loss, "Loss should not be None")
@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="BlipModel does not have input/output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="Tested in individual model tests")
def test_compile_tf_model(self):
pass
@unittest.skip("Model doesn't have a clean loss output.")
def test_keras_fit(self):
pass
@require_tf
class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -643,7 +721,7 @@ class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
@require_tf
class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlipForConditionalGeneration, TFBlipForQuestionAnswering) if is_tf_available() else ()
all_model_classes = (TFBlipForConditionalGeneration,) if is_tf_available() else ()
test_head_masking = False
test_pruning = False
test_resize_embeddings = False