Fix CI: test_inference_for_pretraining in ViTMAEModelTest (#16591)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-04-05 10:00:03 +02:00
committed by GitHub
parent 104c065277
commit 765bafb8e4

View File

@@ -561,7 +561,7 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
# forward pass # forward pass
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs, noise=torch.from_numpy(noise)) outputs = model(**inputs, noise=torch.from_numpy(noise).to(device=torch_device))
# verify the logits # verify the logits
expected_shape = torch.Size((1, 196, 768)) expected_shape = torch.Size((1, 196, 768))