Fix torch device issues (#20304)

* fix device issue

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-11-21 10:12:25 +01:00
committed by GitHub
parent d316037ad7
commit 8503cc7550
6 changed files with 8 additions and 8 deletions

View File

@@ -569,9 +569,9 @@ class DeformableDetrModelIntegrationTests(unittest.TestCase):
results = feature_extractor.post_process_object_detection(
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
)[0]
expected_scores = torch.tensor([0.7999, 0.7894, 0.6331, 0.4720, 0.4382])
expected_scores = torch.tensor([0.7999, 0.7894, 0.6331, 0.4720, 0.4382]).to(torch_device)
expected_labels = [17, 17, 75, 75, 63]
expected_slice_boxes = torch.tensor([16.5028, 52.8390, 318.2544, 470.7841])
expected_slice_boxes = torch.tensor([16.5028, 52.8390, 318.2544, 470.7841]).to(torch_device)
self.assertEqual(len(results["scores"]), 5)
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4))