DeformableDetrModel support fp16 (#29013)
* Update ms_deform_attn_cuda.cu * Update ms_deform_attn_cuda.cuh * Update modeling_deformable_detr.py * Update src/transformers/models/deformable_detr/modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update modeling_deformable_detr.py * python utils/check_copies.py --fix_and_overwrite * Fix dtype missmatch error * Update test_modeling_deformable_detr.py * Update test_modeling_deformable_detr.py * Update modeling_deformable_detr.py * Update modeling_deformable_detr.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -583,6 +583,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
def create_and_check_model_fp16_forward(self):
|
||||
model_class = DeformableDetrForObjectDetection
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
output = model(**inputs)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
Reference in New Issue
Block a user