DeformableDETR support bfloat16 (#29232)
* Update ms_deform_attn_cuda.cu * Update ms_deform_attn_cuda.cuh * Update modeling_deformable_detr.py * Update src/transformers/models/deformable_detr/modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update modeling_deformable_detr.py * python utils/check_copies.py --fix_and_overwrite * Fix dtype missmatch error * Update test_modeling_deformable_detr.py * Update test_modeling_deformable_detr.py * Update modeling_deformable_detr.py * Update modeling_deformable_detr.py * Support DeformableDETR with bfloat16 * Add test code * Use AT_DISPATCH_FLOATING_TYPES_AND2 Use AT_DISPATCH_FLOATING_TYPES_AND2 * Update tests/models/deformable_detr/test_modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/deformable_detr/test_modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix not found require_torch_bf16 function --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from transformers.testing_utils import (
|
||||
require_timm,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_bf16,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -591,6 +592,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
output = model(**inputs)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
@require_torch_bf16
|
||||
def create_and_check_model_bf16_forward(self):
|
||||
model_class = DeformableDetrForObjectDetection
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
model = model_class(config, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
output = model(**inputs)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
Reference in New Issue
Block a user