DeformableDetrModel support fp16 (#29013)

* Update ms_deform_attn_cuda.cu

* Update ms_deform_attn_cuda.cuh

* Update modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update modeling_deformable_detr.py

* python utils/check_copies.py --fix_and_overwrite

* Fix dtype missmatch error

* Update test_modeling_deformable_detr.py

* Update test_modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* Update modeling_deformable_detr.py

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Donggeun Yu
2024-02-15 21:31:09 +09:00
committed by GitHub
parent 83e96dc0ab
commit 5b6fa2306a
5 changed files with 29 additions and 16 deletions

View File

@@ -583,6 +583,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
loss = model(**inputs).loss
loss.backward()
def create_and_check_model_fp16_forward(self):
model_class = DeformableDetrForObjectDetection
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
model.to(torch_device)
model.half()
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
output = model(**inputs)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
TOLERANCE = 1e-4