[Tests] Fix ViTMAE integration test (#15949)
* Fix test across both cpu and gpu * Fix typo
This commit is contained in:
@@ -401,6 +401,9 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_inference_for_pretraining(self):
|
def test_inference_for_pretraining(self):
|
||||||
# make random mask reproducible
|
# make random mask reproducible
|
||||||
|
# note that the same seed on CPU and on GPU doesn’t mean they spew the same random number sequences,
|
||||||
|
# as they both have fairly different PRNGs (for efficiency reasons).
|
||||||
|
# source: https://discuss.pytorch.org/t/random-seed-that-spans-across-devices/19735
|
||||||
torch.manual_seed(2)
|
torch.manual_seed(2)
|
||||||
|
|
||||||
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
|
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
|
||||||
@@ -417,8 +420,14 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
|
|||||||
expected_shape = torch.Size((1, 196, 768))
|
expected_shape = torch.Size((1, 196, 768))
|
||||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = torch.tensor(
|
expected_slice_cpu = torch.tensor(
|
||||||
[[0.7366, -1.3663, -0.2844], [0.7919, -1.3839, -0.3241], [0.4313, -0.7168, -0.2878]]
|
[[0.7366, -1.3663, -0.2844], [0.7919, -1.3839, -0.3241], [0.4313, -0.7168, -0.2878]]
|
||||||
).to(torch_device)
|
)
|
||||||
|
expected_slice_gpu = torch.tensor(
|
||||||
|
[[0.8948, -1.0680, 0.0030], [0.9758, -1.1181, -0.0290], [1.0602, -1.1522, -0.0528]]
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4))
|
# set expected slice depending on device
|
||||||
|
expected_slice = expected_slice_cpu if torch_device == "cpu" else expected_slice_gpu
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user