blip support for training (#21021)
* `blip` support for training * remove labels creation * remove unneeded `decoder_input_ids` creation * final changes - add colab link to documentation - reduction = mean for loss * fix nits * update link * clearer error message
This commit is contained in:
@@ -31,6 +31,10 @@ However, most existing pre-trained models only excel in either understanding-bas
|
|||||||
This model was contributed by [ybelkada](https://huggingface.co/ybelkada).
|
This model was contributed by [ybelkada](https://huggingface.co/ybelkada).
|
||||||
The original code can be found [here](https://github.com/salesforce/BLIP).
|
The original code can be found [here](https://github.com/salesforce/BLIP).
|
||||||
|
|
||||||
|
## Resources
|
||||||
|
|
||||||
|
- [Jupyter notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_blip.ipynb) on how to fine-tune BLIP for image captioning on a custom dataset
|
||||||
|
|
||||||
|
|
||||||
## BlipConfig
|
## BlipConfig
|
||||||
|
|
||||||
|
|||||||
@@ -1014,6 +1014,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
|
|||||||
encoder_hidden_states=image_embeds,
|
encoder_hidden_states=image_embeds,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
reduction="mean",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
@@ -1125,7 +1126,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
|||||||
self.text_decoder = BlipTextLMHeadModel(config.text_config)
|
self.text_decoder = BlipTextLMHeadModel(config.text_config)
|
||||||
|
|
||||||
self.decoder_pad_token_id = config.text_config.pad_token_id
|
self.decoder_pad_token_id = config.text_config.pad_token_id
|
||||||
self.decoder_bos_token_id = config.text_config.bos_token_id
|
self.decoder_start_token_id = config.text_config.bos_token_id
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
@@ -1133,6 +1134,19 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
|||||||
def get_input_embeddings(self) -> nn.Module:
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
return self.vision_model.embeddings.patch_embedding
|
return self.vision_model.embeddings.patch_embedding
|
||||||
|
|
||||||
|
# Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
|
||||||
|
def _shift_right(self, input_ids):
|
||||||
|
pad_token_id = self.decoder_pad_token_id
|
||||||
|
|
||||||
|
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||||
|
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||||
|
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
||||||
|
|
||||||
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
|
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
||||||
|
|
||||||
|
return shifted_input_ids
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1168,8 +1182,14 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
|||||||
|
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
```"""
|
```"""
|
||||||
|
if labels is None and decoder_input_ids is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
|
||||||
|
" `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
|
||||||
|
" are using the model for inference make sure that `decoder_input_ids` is passed."
|
||||||
|
)
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
batch_size = input_ids.shape[0]
|
|
||||||
|
|
||||||
vision_outputs = self.vision_model(
|
vision_outputs = self.vision_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
@@ -1191,11 +1211,11 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
|||||||
|
|
||||||
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
||||||
|
|
||||||
if decoder_input_ids is None:
|
if labels is not None and decoder_input_ids is None:
|
||||||
decoder_input_ids = torch.LongTensor([self.decoder_bos_token_id]).repeat((batch_size, 1))
|
# get decoder inputs from shifting lm labels to the right - this is used in training mode
|
||||||
|
decoder_input_ids = self._shift_right(labels)
|
||||||
if labels is None:
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
labels = decoder_input_ids.masked_fill(decoder_input_ids == self.decoder_pad_token_id, -100)
|
labels = labels.masked_fill(labels == self.decoder_pad_token_id, -100)
|
||||||
|
|
||||||
answer_output = self.text_decoder(
|
answer_output = self.text_decoder(
|
||||||
input_ids=decoder_input_ids,
|
input_ids=decoder_input_ids,
|
||||||
@@ -1204,10 +1224,13 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
|||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
reduction="none",
|
reduction="mean",
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean()
|
if labels is not None:
|
||||||
|
decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean()
|
||||||
|
else:
|
||||||
|
decoder_loss = None
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:]
|
outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:]
|
||||||
@@ -1288,7 +1311,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
|||||||
question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device)
|
question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device)
|
||||||
|
|
||||||
bos_ids = torch.full(
|
bos_ids = torch.full(
|
||||||
(question_embeds.size(0), 1), fill_value=self.decoder_bos_token_id, device=question_embeds.device
|
(question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = self.text_decoder.generate(
|
outputs = self.text_decoder.generate(
|
||||||
@@ -1330,8 +1353,16 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel):
|
|||||||
# image text matching head
|
# image text matching head
|
||||||
self.itm_head = nn.Linear(config.text_config.hidden_size, 2)
|
self.itm_head = nn.Linear(config.text_config.hidden_size, 2)
|
||||||
|
|
||||||
self.decoder_pad_token_id = config.text_config.pad_token_id
|
self.decoder_pad_token_id = (
|
||||||
self.decoder_bos_token_id = config.text_config.bos_token_id
|
config.text_config.pad_token_id
|
||||||
|
if not hasattr(config, "decoder_pad_token_id")
|
||||||
|
else config.decoder_pad_token_id
|
||||||
|
)
|
||||||
|
self.decoder_start_token_id = (
|
||||||
|
config.text_config.bos_token_id
|
||||||
|
if not hasattr(config, "decoder_start_token_id")
|
||||||
|
else config.decoder_start_token_id
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|||||||
@@ -731,7 +731,7 @@ class BlipTextModel(BlipTextPreTrainedModel):
|
|||||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)))
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length))).to(device)
|
||||||
|
|
||||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
|||||||
@@ -521,7 +521,7 @@ class BlipModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
class BlipTextImageModelsModelTester:
|
class BlipTextRetrievalModelTester:
|
||||||
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||||
|
|
||||||
if text_kwargs is None:
|
if text_kwargs is None:
|
||||||
@@ -569,13 +569,319 @@ class BlipTextImageModelsModelTester:
|
|||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
class BlipTextImageModelsModelTester:
|
||||||
|
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||||
|
|
||||||
|
if text_kwargs is None:
|
||||||
|
text_kwargs = {}
|
||||||
|
if vision_kwargs is None:
|
||||||
|
vision_kwargs = {}
|
||||||
|
|
||||||
|
self.parent = parent
|
||||||
|
self.text_model_tester = BlipTextModelTester(parent, **text_kwargs)
|
||||||
|
self.vision_model_tester = BlipVisionModelTester(parent, **vision_kwargs)
|
||||||
|
self.is_training = is_training
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||||
|
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, input_ids, attention_mask, pixel_values
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return BlipConfig.from_text_vision_configs(
|
||||||
|
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||||
|
model = BlipModel(config).to(torch_device).eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
result = model(input_ids, pixel_values, attention_mask)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"labels": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
class BlipVQAModelTest(unittest.TestCase):
|
||||||
|
all_model_classes = (BlipForQuestionAnswering,) if is_torch_available() else ()
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = BlipModelTester(self)
|
||||||
|
|
||||||
|
def _prepare_inputs_for_vqa(self):
|
||||||
|
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||||
|
inputs_dict.pop("return_loss")
|
||||||
|
return inputs_dict
|
||||||
|
|
||||||
|
def test_class_name_consistency(self):
|
||||||
|
"""
|
||||||
|
Tests that all VQA models have a class name that ends with "ForQuestionAnswering"
|
||||||
|
"""
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(self.model_tester.get_config())
|
||||||
|
self.assertTrue(
|
||||||
|
model.__class__.__name__.endswith("ForQuestionAnswering"),
|
||||||
|
f"Class name should end with 'ForVisualQuestionAnswering' got {model.__class__.__name__}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_training(self):
|
||||||
|
"""
|
||||||
|
Tests that all VQA models can be trained on a single batch
|
||||||
|
"""
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(self.model_tester.get_config()).to(torch_device)
|
||||||
|
model.train()
|
||||||
|
loss = model(**self._prepare_inputs_for_vqa()).loss
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# verify the gradients are not None
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
self.assertIsNotNone(param.grad, f"Gradients should not be None - got {param.grad} for {name}")
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
"""
|
||||||
|
Test if the forward function has the expected arguments.
|
||||||
|
"""
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(self.model_tester.get_config())
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
# signature.parameters is an OrderedDict => so args are the first n entries
|
||||||
|
args = list(signature.parameters.keys())
|
||||||
|
expected_args = [
|
||||||
|
"input_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"labels",
|
||||||
|
"decoder_input_ids",
|
||||||
|
"decoder_attention_mask",
|
||||||
|
]
|
||||||
|
for arg in expected_args:
|
||||||
|
self.assertTrue(
|
||||||
|
arg in args,
|
||||||
|
f"Argument {arg} of forward function signature should include {arg}. Found {args}.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (BlipForImageTextRetrieval,) if is_torch_available() else ()
|
||||||
|
fx_compatible = False
|
||||||
|
test_head_masking = False
|
||||||
|
test_pruning = False
|
||||||
|
test_resize_embeddings = False
|
||||||
|
test_attention_outputs = False
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = BlipTextRetrievalModelTester(self)
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="BlipModel does not have input/output embeddings")
|
||||||
|
def test_model_common_attributes(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
expected_arg_names = [
|
||||||
|
"input_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"decoder_input_ids",
|
||||||
|
"decoder_attention_mask",
|
||||||
|
]
|
||||||
|
expected_arg_names.extend(
|
||||||
|
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
|
||||||
|
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
|
||||||
|
else ["encoder_outputs"]
|
||||||
|
)
|
||||||
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
|
else:
|
||||||
|
expected_arg_names = ["input_ids"] if model_class != BlipForConditionalGeneration else ["pixel_values"]
|
||||||
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
|
def test_training(self):
|
||||||
|
if not self.model_tester.is_training:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes[:-1]:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.return_dict = True
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|
||||||
|
# hardcode labels to be the same as input_ids
|
||||||
|
inputs["labels"] = inputs["input_ids"]
|
||||||
|
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
def test_training_gradient_checkpointing(self):
|
||||||
|
if not self.model_tester.is_training:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes[:-1]:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.use_cache = False
|
||||||
|
config.return_dict = True
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
model.train()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|
||||||
|
# hardcode labels to be the same as input_ids
|
||||||
|
inputs["labels"] = inputs["input_ids"]
|
||||||
|
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# override as the `logit_scale` parameter initilization is different for Blip
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
# check if `logit_scale` is initilized as per the original implementation
|
||||||
|
if name == "logit_scale":
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
param.data.item(),
|
||||||
|
np.log(1 / 0.07),
|
||||||
|
delta=1e-3,
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||||
|
if not self.test_torchscript:
|
||||||
|
return
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||||
|
configs_no_init.torchscript = True
|
||||||
|
configs_no_init.return_dict = False
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
try:
|
||||||
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
pixel_values = inputs_dict["pixel_values"] # Blip needs pixel_values
|
||||||
|
traced_model = torch.jit.trace(model, (input_ids, pixel_values))
|
||||||
|
except RuntimeError:
|
||||||
|
self.fail("Couldn't trace module.")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.jit.save(traced_model, pt_file_name)
|
||||||
|
except Exception:
|
||||||
|
self.fail("Couldn't save module.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
loaded_model = torch.jit.load(pt_file_name)
|
||||||
|
except Exception:
|
||||||
|
self.fail("Couldn't load module.")
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
loaded_model.to(torch_device)
|
||||||
|
loaded_model.eval()
|
||||||
|
|
||||||
|
model_state_dict = model.state_dict()
|
||||||
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
models_equal = True
|
||||||
|
for layer_name, p1 in model_state_dict.items():
|
||||||
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
models_equal = False
|
||||||
|
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
def test_load_vision_text_config(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# Save BlipConfig and check if we can load BlipVisionConfig from it
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
config.save_pretrained(tmp_dir_name)
|
||||||
|
vision_config = BlipVisionConfig.from_pretrained(tmp_dir_name)
|
||||||
|
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
|
||||||
|
|
||||||
|
# Save BlipConfig and check if we can load BlipTextConfig from it
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
config.save_pretrained(tmp_dir_name)
|
||||||
|
text_config = BlipTextConfig.from_pretrained(tmp_dir_name)
|
||||||
|
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
for model_name in BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
model = BlipModel.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
BlipForConditionalGeneration,
|
BlipForConditionalGeneration,
|
||||||
BlipForQuestionAnswering,
|
BlipForQuestionAnswering,
|
||||||
BlipForImageTextRetrieval,
|
|
||||||
)
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
@@ -648,6 +954,10 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|
||||||
|
# hardcode labels to be the same as input_ids
|
||||||
|
inputs["labels"] = inputs["input_ids"]
|
||||||
|
|
||||||
loss = model(**inputs).loss
|
loss = model(**inputs).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
@@ -665,6 +975,10 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|
||||||
|
# hardcode labels to be the same as input_ids
|
||||||
|
inputs["labels"] = inputs["input_ids"]
|
||||||
|
|
||||||
loss = model(**inputs).loss
|
loss = model(**inputs).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user