DeformableDETR support bfloat16 (#29232)

* Update ms_deform_attn_cuda.cu

* Update ms_deform_attn_cuda.cuh

* Update modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update modeling_deformable_detr.py

* python utils/check_copies.py --fix_and_overwrite

* Fix dtype missmatch error

* Update test_modeling_deformable_detr.py

* Update test_modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* Support DeformableDETR with bfloat16

* Add test code

* Use AT_DISPATCH_FLOATING_TYPES_AND2

Use AT_DISPATCH_FLOATING_TYPES_AND2

* Update tests/models/deformable_detr/test_modeling_deformable_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/deformable_detr/test_modeling_deformable_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix not found require_torch_bf16 function

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Donggeun Yu
2024-03-04 23:18:09 +09:00
committed by GitHub
parent bcd23a54f1
commit ed74d97871
5 changed files with 34 additions and 5 deletions

View File

@@ -26,6 +26,7 @@ from transformers.testing_utils import (
require_timm,
require_torch,
require_torch_accelerator,
require_torch_bf16,
require_vision,
slow,
torch_device,
@@ -591,6 +592,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
output = model(**inputs)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
@require_torch_bf16
def create_and_check_model_bf16_forward(self):
model_class = DeformableDetrForObjectDetection
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config, torch_dtype=torch.bfloat16)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
output = model(**inputs)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
TOLERANCE = 1e-4