From b1ec745475dd723c6b5a7c62b2cbce0c7dc4abbd Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 26 Jun 2024 17:50:10 +0100 Subject: [PATCH] Fix RT-DETR inference with float16 and bfloat16 (#31639) * [run_slow] rt_detr * Fix positional embeddings and anchors dtypes * [run slow] rt_detr * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fixup --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/rt_detr/modeling_rt_detr.py | 7 ++--- tests/models/rt_detr/test_modeling_rt_detr.py | 26 ++++++++++++++++++- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 807589cc62..26cf843357 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -1359,7 +1359,7 @@ class RTDetrHybridEncoder(nn.Module): if self.training or self.eval_size is None: pos_embed = self.build_2d_sincos_position_embedding( width, height, self.encoder_hidden_dim, self.positional_encoding_temperature - ).to(src_flatten.device) + ).to(src_flatten.device, src_flatten.dtype) else: pos_embed = None @@ -1801,12 +1801,13 @@ class RTDetrModel(RTDetrPreTrainedModel): batch_size = len(source_flatten) device = source_flatten.device + dtype = source_flatten.dtype # prepare input for decoder if self.training or self.config.anchor_image_size is None: - anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device) + anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device, dtype=dtype) else: - anchors, valid_mask = self.anchors.to(device), self.valid_mask.to(device) + anchors, valid_mask = self.anchors.to(device, dtype), self.valid_mask.to(device, dtype) # use the valid_mask to selectively retain values in the feature map where the mask is `True` memory = valid_mask.to(source_flatten.dtype) * source_flatten diff --git a/tests/models/rt_detr/test_modeling_rt_detr.py b/tests/models/rt_detr/test_modeling_rt_detr.py index 05ce68fb92..44647be5ac 100644 --- a/tests/models/rt_detr/test_modeling_rt_detr.py +++ b/tests/models/rt_detr/test_modeling_rt_detr.py @@ -18,6 +18,8 @@ import inspect import math import unittest +from parameterized import parameterized + from transformers import ( RTDetrConfig, RTDetrImageProcessor, @@ -25,7 +27,7 @@ from transformers import ( is_torch_available, is_vision_available, ) -from transformers.testing_utils import require_torch, require_vision, torch_device +from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device from transformers.utils import cached_property from ...test_configuration_common import ConfigTester @@ -606,6 +608,28 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_gpu + @slow + def test_inference_with_different_dtypes(self, torch_dtype_str): + torch_dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[torch_dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device).to(torch_dtype) + model.eval() + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(torch_dtype) + with torch.no_grad(): + _ = model(**self._prepare_for_class(inputs_dict, model_class)) + TOLERANCE = 1e-4