Fix FlavaForPreTrainingIntegrationTest CI test (#17232)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1219,6 +1219,6 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
|
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
|
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
|
||||||
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199)
|
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
|
||||||
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
|
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
|
||||||
self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
|
self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
|
||||||
|
|||||||
Reference in New Issue
Block a user