Fix pos_mask application and update tests accordingly (#27892)
* Fix pos_mask application and update tests accordingly
* Fix style
* Adding comments
---------
Co-authored-by: Fernando Rodriguez <fernando.rodriguez@nielseniq.com>
This commit is contained in:
committed by
GitHub
parent
03b980990a
commit
57e9c83213
@@ -1949,6 +1949,7 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
|
|||||||
|
|
||||||
if mim_labels is not None:
|
if mim_labels is not None:
|
||||||
mim_labels = mim_labels[pos_mask]
|
mim_labels = mim_labels[pos_mask]
|
||||||
|
bool_masked_pos = bool_masked_pos[pos_mask]
|
||||||
|
|
||||||
# MMM Image Loss
|
# MMM Image Loss
|
||||||
if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
|
if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
|
||||||
@@ -1956,8 +1957,6 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
|
|||||||
end_index = image_masked_embeddings.size(1) - 1
|
end_index = image_masked_embeddings.size(1) - 1
|
||||||
sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
|
sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
|
||||||
|
|
||||||
if pos_mask is not None:
|
|
||||||
sequence_for_image = sequence_for_image[pos_mask]
|
|
||||||
if mim_labels is not None:
|
if mim_labels is not None:
|
||||||
mim_labels = self._resize_to_2d(mim_labels)
|
mim_labels = self._resize_to_2d(mim_labels)
|
||||||
bool_masked_pos = self._resize_to_2d(bool_masked_pos)
|
bool_masked_pos = self._resize_to_2d(bool_masked_pos)
|
||||||
@@ -1979,8 +1978,6 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
|
|||||||
if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
|
if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
|
||||||
sequence_for_text = multimodal_masked_embeddings
|
sequence_for_text = multimodal_masked_embeddings
|
||||||
sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
|
sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
|
||||||
if pos_mask is not None:
|
|
||||||
sequence_for_text = sequence_for_text[pos_mask]
|
|
||||||
|
|
||||||
if mlm_labels is not None:
|
if mlm_labels is not None:
|
||||||
mlm_labels = self._resize_to_2d(mlm_labels)
|
mlm_labels = self._resize_to_2d(mlm_labels)
|
||||||
|
|||||||
@@ -1313,8 +1313,12 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
|
|||||||
return_codebook_pixels=True,
|
return_codebook_pixels=True,
|
||||||
return_image_mask=True,
|
return_image_mask=True,
|
||||||
)
|
)
|
||||||
|
# Create a clone of the input_ids tensor that will be its masked version
|
||||||
inputs["input_ids_masked"] = inputs["input_ids"].clone()
|
inputs["input_ids_masked"] = inputs["input_ids"].clone()
|
||||||
|
# Mask the tokens "a" & "cat" from the "a photo of a cat" text using the special 103 value
|
||||||
inputs["input_ids_masked"][0, 4:6] = 103
|
inputs["input_ids_masked"][0, 4:6] = 103
|
||||||
|
# MLM labels. It is a cloned version of input_ids where all values are -100 (i.e., ignored)
|
||||||
|
# except those that are masked, whose original values are stored
|
||||||
inputs["mlm_labels"] = inputs["input_ids"].clone()
|
inputs["mlm_labels"] = inputs["input_ids"].clone()
|
||||||
inputs["mlm_labels"][:, :] = -100
|
inputs["mlm_labels"][:, :] = -100
|
||||||
inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
|
inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
|
||||||
@@ -1338,3 +1342,54 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
|
|||||||
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
|
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
|
||||||
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
|
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
|
||||||
self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
|
self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_with_itm_labels(self):
|
||||||
|
model_name = "facebook/flava-full"
|
||||||
|
model = FlavaForPreTraining.from_pretrained(model_name).to(torch_device)
|
||||||
|
processor = FlavaProcessor.from_pretrained(model_name)
|
||||||
|
torch.manual_seed(1)
|
||||||
|
random.seed(1)
|
||||||
|
|
||||||
|
image = prepare_img()
|
||||||
|
inputs = processor(
|
||||||
|
text=["a photo of a cat", "a photo of a dog"],
|
||||||
|
images=[image, image],
|
||||||
|
padding="max_length",
|
||||||
|
max_length=77,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_codebook_pixels=True,
|
||||||
|
return_image_mask=True,
|
||||||
|
)
|
||||||
|
# Create a clone of the input_ids tensor that will be its masked version
|
||||||
|
inputs["input_ids_masked"] = inputs["input_ids"].clone()
|
||||||
|
# Mask the tokens "a" & "cat" from the "a photo of a cat" text using the special 103 value
|
||||||
|
inputs["input_ids_masked"][0, 4:6] = 103
|
||||||
|
# MLM labels. It is a cloned version of input_ids where all values are -100 (i.e., ignored)
|
||||||
|
# except those that are masked, whose original values are stored
|
||||||
|
inputs["mlm_labels"] = inputs["input_ids"].clone()
|
||||||
|
inputs["mlm_labels"][:, :] = -100
|
||||||
|
inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
|
||||||
|
# Manually create the itm_labels tensor that indicates whether each image-text pair matches.
|
||||||
|
# In this case, the first pair matches and the second does not
|
||||||
|
inputs["itm_labels"] = torch.tensor([1, 0])
|
||||||
|
inputs = inputs.to(torch_device)
|
||||||
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
# verify the logits
|
||||||
|
self.assertEqual(
|
||||||
|
outputs.contrastive_logits_per_image.shape,
|
||||||
|
torch.Size((torch.count_nonzero(inputs["itm_labels"]).item(), inputs.input_ids.shape[0])),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs.contrastive_logits_per_text.shape,
|
||||||
|
torch.Size((torch.count_nonzero(inputs["itm_labels"]).item(), inputs.pixel_values.shape[0])),
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
|
||||||
|
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
|
||||||
|
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
|
||||||
|
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.89590501, places=4)
|
||||||
|
self.assertAlmostEqual(outputs.loss.item(), 9.1995, places=4)
|
||||||
|
|||||||
Reference in New Issue
Block a user