Fix FlavaForPreTrainingIntegrationTest CI test (#17232)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-05-16 21:14:25 +02:00
committed by GitHub
parent 9b0d2860eb
commit 3fb82f74fd

View File

@@ -1219,6 +1219,6 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device) expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3)) self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199) self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4) self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4) self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)