Fix RT-DETR inference with float16 and bfloat16 (#31639)
* [run_slow] rt_detr
* Fix positional embeddings and anchors dtypes
* [run slow] rt_detr
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fixup

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3f93fd0694
commit
b1ec745475
@@ -1359,7 +1359,7 @@ class RTDetrHybridEncoder(nn.Module):
|
|||||||
if self.training or self.eval_size is None:
|
if self.training or self.eval_size is None:
|
||||||
pos_embed = self.build_2d_sincos_position_embedding(
|
pos_embed = self.build_2d_sincos_position_embedding(
|
||||||
width, height, self.encoder_hidden_dim, self.positional_encoding_temperature
|
width, height, self.encoder_hidden_dim, self.positional_encoding_temperature
|
||||||
).to(src_flatten.device)
|
).to(src_flatten.device, src_flatten.dtype)
|
||||||
else:
|
else:
|
||||||
pos_embed = None
|
pos_embed = None
|
||||||
|
|
||||||
@@ -1801,12 +1801,13 @@ class RTDetrModel(RTDetrPreTrainedModel):
|
|||||||
|
|
||||||
batch_size = len(source_flatten)
|
batch_size = len(source_flatten)
|
||||||
device = source_flatten.device
|
device = source_flatten.device
|
||||||
|
dtype = source_flatten.dtype
|
||||||
|
|
||||||
# prepare input for decoder
|
# prepare input for decoder
|
||||||
if self.training or self.config.anchor_image_size is None:
|
if self.training or self.config.anchor_image_size is None:
|
||||||
anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device)
|
anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
anchors, valid_mask = self.anchors.to(device), self.valid_mask.to(device)
|
anchors, valid_mask = self.anchors.to(device, dtype), self.valid_mask.to(device, dtype)
|
||||||
|
|
||||||
# use the valid_mask to selectively retain values in the feature map where the mask is `True`
|
# use the valid_mask to selectively retain values in the feature map where the mask is `True`
|
||||||
memory = valid_mask.to(source_flatten.dtype) * source_flatten
|
memory = valid_mask.to(source_flatten.dtype) * source_flatten
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ import inspect
|
|||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
RTDetrConfig,
|
RTDetrConfig,
|
||||||
RTDetrImageProcessor,
|
RTDetrImageProcessor,
|
||||||
@@ -25,7 +27,7 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_torch, require_vision, torch_device
|
from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
|
||||||
from transformers.utils import cached_property
|
from transformers.utils import cached_property
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -606,6 +608,28 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@parameterized.expand(["float32", "float16", "bfloat16"])
@require_torch_gpu
@slow
def test_inference_with_different_dtypes(self, torch_dtype_str):
    """Smoke-test a forward pass of every model class in the given precision.

    The test only checks that inference runs without raising; the model
    output is discarded and no numerical assertions are made.

    Args:
        torch_dtype_str: Name of the precision under test, one of
            ``"float32"``, ``"float16"`` or ``"bfloat16"``.
    """
    torch_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[torch_dtype_str]

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Cast float32 inputs to the target precision once, before the model
    # loop: the conversion does not depend on the model class. Non-tensor
    # values and tensors of any other dtype are passed through untouched.
    inputs_dict = {
        key: (
            value.to(torch_dtype)
            if isinstance(value, torch.Tensor) and value.dtype == torch.float32
            else value
        )
        for key, value in inputs_dict.items()
    }

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device).to(torch_dtype)
        model.eval()
        with torch.no_grad():
            # Result intentionally unused: reaching this point without an
            # exception is the pass criterion.
            _ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
|
||||||
TOLERANCE = 1e-4
|
TOLERANCE = 1e-4
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user