Fix RT-DETR cache for generate_anchors (#31671)
* Fix cache and type conversion
* Add test
* Fixup
* nit
* [run slow] rt_detr
* Fix test
* Fixup
* [run slow] rt_detr
* Update src/transformers/models/rt_detr/modeling_rt_detr.py
This commit is contained in:
committed by
GitHub
parent
534cbf8a5d
commit
b97521614a
@@ -16,6 +16,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
@@ -630,6 +631,48 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
@parameterized.expand(["float32", "float16", "bfloat16"])
@require_torch_gpu
@slow
def test_inference_equivalence_for_static_and_dynamic_anchors(self, torch_dtype_str):
    """Compare a model loaded with a fixed `anchor_image_size` against one loaded with `None`.

    For every model class, a freshly initialised model is saved to a temporary
    directory and reloaded twice in the requested dtype: once with
    `anchor_image_size` set to the input's height and width ("static") and once
    with `anchor_image_size=None` ("dynamic"). Both are run on the same inputs
    and their `last_hidden_state` tensors must match within rtol/atol of 1e-4.

    Args:
        torch_dtype_str: One of "float32", "float16" or "bfloat16"; supplied by
            `parameterized.expand`.
    """
    dtype_by_name = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    torch_dtype = dtype_by_name[torch_dtype_str]

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    height, width = inputs_dict["pixel_values"].shape[-2:]

    # Cast only the float32 inputs to the dtype under test; everything else is kept as is.
    inputs_dict = {
        name: value.to(torch_dtype) if value.dtype == torch.float32 else value
        for name, value in inputs_dict.items()
    }

    for model_class in self.all_model_classes:
        # Both variants are loaded from the same checkpoint, so they share identical weights.
        load_kwargs = {"device_map": torch_device, "torch_dtype": torch_dtype}
        with tempfile.TemporaryDirectory() as checkpoint_dir:
            model_class(config).save_pretrained(checkpoint_dir)
            model_static = model_class.from_pretrained(
                checkpoint_dir, anchor_image_size=[height, width], **load_kwargs
            ).eval()
            model_dynamic = model_class.from_pretrained(
                checkpoint_dir, anchor_image_size=None, **load_kwargs
            ).eval()

        # Make sure the overrides passed to `from_pretrained` reached the configs.
        self.assertIsNotNone(model_static.config.anchor_image_size)
        self.assertIsNone(model_dynamic.config.anchor_image_size)

        with torch.no_grad():
            outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class))
            outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class))

        hidden_static = outputs_static.last_hidden_state
        hidden_dynamic = outputs_dynamic.last_hidden_state
        self.assertTrue(
            torch.allclose(hidden_static, hidden_dynamic, rtol=1e-4, atol=1e-4),
            f"Max diff: {(hidden_static - hidden_dynamic).abs().max()}",
        )
|
||||
|
||||
|
||||
# Module-level numeric tolerance; the code that consumes it lies outside this excerpt.
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
Reference in New Issue
Block a user