Fix torch device issues (#20304)
* fix device issue Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -511,9 +511,9 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
|
||||
results = feature_extractor.post_process_object_detection(
|
||||
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.8330, 0.8313, 0.8039, 0.6829, 0.5355])
|
||||
expected_scores = torch.tensor([0.8330, 0.8313, 0.8039, 0.6829, 0.5355]).to(torch_device)
|
||||
expected_labels = [75, 17, 17, 75, 63]
|
||||
expected_slice_boxes = torch.tensor([38.3089, 72.1022, 177.6293, 118.4512])
|
||||
expected_slice_boxes = torch.tensor([38.3089, 72.1022, 177.6293, 118.4512]).to(torch_device)
|
||||
|
||||
self.assertEqual(len(results["scores"]), 5)
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user