Fix torch device issues (#20304)
* fix device issue Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -881,7 +881,7 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac
|
|||||||
img_w = torch.Tensor([i[1] for i in target_sizes])
|
img_w = torch.Tensor([i[1] for i in target_sizes])
|
||||||
else:
|
else:
|
||||||
img_h, img_w = target_sizes.unbind(1)
|
img_h, img_w = target_sizes.unbind(1)
|
||||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
||||||
boxes = boxes * scale_fct[:, None, :]
|
boxes = boxes * scale_fct[:, None, :]
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|||||||
@@ -729,7 +729,7 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract
|
|||||||
img_w = torch.Tensor([i[1] for i in target_sizes])
|
img_w = torch.Tensor([i[1] for i in target_sizes])
|
||||||
else:
|
else:
|
||||||
img_h, img_w = target_sizes.unbind(1)
|
img_h, img_w = target_sizes.unbind(1)
|
||||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
||||||
boxes = boxes * scale_fct[:, None, :]
|
boxes = boxes * scale_fct[:, None, :]
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|||||||
@@ -1103,7 +1103,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
else:
|
else:
|
||||||
img_h, img_w = target_sizes.unbind(1)
|
img_h, img_w = target_sizes.unbind(1)
|
||||||
|
|
||||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
||||||
boxes = boxes * scale_fct[:, None, :]
|
boxes = boxes * scale_fct[:, None, :]
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|||||||
@@ -694,7 +694,7 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
|
|||||||
else:
|
else:
|
||||||
img_h, img_w = target_sizes.unbind(1)
|
img_h, img_w = target_sizes.unbind(1)
|
||||||
|
|
||||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
||||||
boxes = boxes * scale_fct[:, None, :]
|
boxes = boxes * scale_fct[:, None, :]
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|||||||
@@ -511,9 +511,9 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
|
|||||||
results = feature_extractor.post_process_object_detection(
|
results = feature_extractor.post_process_object_detection(
|
||||||
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
|
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
|
||||||
)[0]
|
)[0]
|
||||||
expected_scores = torch.tensor([0.8330, 0.8313, 0.8039, 0.6829, 0.5355])
|
expected_scores = torch.tensor([0.8330, 0.8313, 0.8039, 0.6829, 0.5355]).to(torch_device)
|
||||||
expected_labels = [75, 17, 17, 75, 63]
|
expected_labels = [75, 17, 17, 75, 63]
|
||||||
expected_slice_boxes = torch.tensor([38.3089, 72.1022, 177.6293, 118.4512])
|
expected_slice_boxes = torch.tensor([38.3089, 72.1022, 177.6293, 118.4512]).to(torch_device)
|
||||||
|
|
||||||
self.assertEqual(len(results["scores"]), 5)
|
self.assertEqual(len(results["scores"]), 5)
|
||||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4))
|
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4))
|
||||||
|
|||||||
@@ -569,9 +569,9 @@ class DeformableDetrModelIntegrationTests(unittest.TestCase):
|
|||||||
results = feature_extractor.post_process_object_detection(
|
results = feature_extractor.post_process_object_detection(
|
||||||
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
|
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
|
||||||
)[0]
|
)[0]
|
||||||
expected_scores = torch.tensor([0.7999, 0.7894, 0.6331, 0.4720, 0.4382])
|
expected_scores = torch.tensor([0.7999, 0.7894, 0.6331, 0.4720, 0.4382]).to(torch_device)
|
||||||
expected_labels = [17, 17, 75, 75, 63]
|
expected_labels = [17, 17, 75, 75, 63]
|
||||||
expected_slice_boxes = torch.tensor([16.5028, 52.8390, 318.2544, 470.7841])
|
expected_slice_boxes = torch.tensor([16.5028, 52.8390, 318.2544, 470.7841]).to(torch_device)
|
||||||
|
|
||||||
self.assertEqual(len(results["scores"]), 5)
|
self.assertEqual(len(results["scores"]), 5)
|
||||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4))
|
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user