Fix tests (#14703)
This commit is contained in:
@@ -860,7 +860,8 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]]
|
||||
[[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4))
|
||||
@@ -970,7 +971,7 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(inputs=patches)
|
||||
outputs = model(inputs=patches.to(torch_device))
|
||||
logits = outputs.logits
|
||||
|
||||
# verify logits
|
||||
|
||||
Reference in New Issue
Block a user