From 8503cc755050c6ed5bc771e3244c29b71be1841e Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 21 Nov 2022 10:12:25 +0100 Subject: [PATCH] Fix torch device issues (#20304) * fix device issue Co-authored-by: ydshieh --- .../conditional_detr/feature_extraction_conditional_detr.py | 2 +- .../deformable_detr/feature_extraction_deformable_detr.py | 2 +- src/transformers/models/detr/feature_extraction_detr.py | 2 +- src/transformers/models/yolos/feature_extraction_yolos.py | 2 +- .../models/conditional_detr/test_modeling_conditional_detr.py | 4 ++-- tests/models/deformable_detr/test_modeling_deformable_detr.py | 4 ++-- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py b/src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py index 608b438479..01efb90542 100644 --- a/src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py +++ b/src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py @@ -881,7 +881,7 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac img_w = torch.Tensor([i[1] for i in target_sizes]) else: img_h, img_w = target_sizes.unbind(1) - scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) boxes = boxes * scale_fct[:, None, :] results = [] diff --git a/src/transformers/models/deformable_detr/feature_extraction_deformable_detr.py b/src/transformers/models/deformable_detr/feature_extraction_deformable_detr.py index 3b5ad2cecd..a618106380 100644 --- a/src/transformers/models/deformable_detr/feature_extraction_deformable_detr.py +++ b/src/transformers/models/deformable_detr/feature_extraction_deformable_detr.py @@ -729,7 +729,7 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract img_w = torch.Tensor([i[1] for i in target_sizes]) else: img_h, img_w = target_sizes.unbind(1) - scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) boxes = boxes * scale_fct[:, None, :] results = [] diff --git a/src/transformers/models/detr/feature_extraction_detr.py b/src/transformers/models/detr/feature_extraction_detr.py index 9898b26586..18c262fea4 100644 --- a/src/transformers/models/detr/feature_extraction_detr.py +++ b/src/transformers/models/detr/feature_extraction_detr.py @@ -1103,7 +1103,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): else: img_h, img_w = target_sizes.unbind(1) - scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) boxes = boxes * scale_fct[:, None, :] results = [] diff --git a/src/transformers/models/yolos/feature_extraction_yolos.py b/src/transformers/models/yolos/feature_extraction_yolos.py index 350037eff3..de6db49434 100644 --- a/src/transformers/models/yolos/feature_extraction_yolos.py +++ b/src/transformers/models/yolos/feature_extraction_yolos.py @@ -694,7 +694,7 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin) else: img_h, img_w = target_sizes.unbind(1) - scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) boxes = boxes * scale_fct[:, None, :] results = [] diff --git a/tests/models/conditional_detr/test_modeling_conditional_detr.py b/tests/models/conditional_detr/test_modeling_conditional_detr.py index b2d5186004..667caa3840 100644 --- a/tests/models/conditional_detr/test_modeling_conditional_detr.py +++ b/tests/models/conditional_detr/test_modeling_conditional_detr.py @@ -511,9 +511,9 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase): results = feature_extractor.post_process_object_detection( outputs, threshold=0.3, target_sizes=[image.size[::-1]] )[0] - expected_scores = torch.tensor([0.8330, 0.8313, 0.8039, 0.6829, 0.5355]) + expected_scores = torch.tensor([0.8330, 0.8313, 0.8039, 0.6829, 0.5355]).to(torch_device) expected_labels = [75, 17, 17, 75, 63] - expected_slice_boxes = torch.tensor([38.3089, 72.1022, 177.6293, 118.4512]) + expected_slice_boxes = torch.tensor([38.3089, 72.1022, 177.6293, 118.4512]).to(torch_device) self.assertEqual(len(results["scores"]), 5) self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4)) diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py index 06823e7fe3..f69d8f15c1 100644 --- a/tests/models/deformable_detr/test_modeling_deformable_detr.py +++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py @@ -569,9 +569,9 @@ class DeformableDetrModelIntegrationTests(unittest.TestCase): results = feature_extractor.post_process_object_detection( outputs, threshold=0.3, target_sizes=[image.size[::-1]] )[0] - expected_scores = torch.tensor([0.7999, 0.7894, 0.6331, 0.4720, 0.4382]) + expected_scores = torch.tensor([0.7999, 0.7894, 0.6331, 0.4720, 0.4382]).to(torch_device) expected_labels = [17, 17, 75, 75, 63] - expected_slice_boxes = torch.tensor([16.5028, 52.8390, 318.2544, 470.7841]) + expected_slice_boxes = torch.tensor([16.5028, 52.8390, 318.2544, 470.7841]).to(torch_device) self.assertEqual(len(results["scores"]), 5) self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4))