DeformableDetrModel support fp16 (#29013)
* Update ms_deform_attn_cuda.cu * Update ms_deform_attn_cuda.cuh * Update modeling_deformable_detr.py * Update src/transformers/models/deformable_detr/modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update modeling_deformable_detr.py * python utils/check_copies.py --fix_and_overwrite * Fix dtype missmatch error * Update test_modeling_deformable_detr.py * Update test_modeling_deformable_detr.py * Update modeling_deformable_detr.py * Update modeling_deformable_detr.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -64,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||
{
|
||||
auto columns = output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||
spatial_shapes.data<int64_t>(),
|
||||
@@ -134,7 +134,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||
{
|
||||
auto grad_output_g = grad_output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||
grad_output_g.data<scalar_t>(),
|
||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||
|
||||
@@ -72,7 +72,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||
{
|
||||
auto columns = output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||
spatial_shapes.data<int64_t>(),
|
||||
@@ -142,7 +142,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||
{
|
||||
auto grad_output_g = grad_output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||
grad_output_g.data<scalar_t>(),
|
||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||
|
||||
@@ -617,7 +617,8 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
|
||||
|
||||
def _reset_parameters(self):
|
||||
nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
|
||||
thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads)
|
||||
default_dtype = torch.get_default_dtype()
|
||||
thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads)
|
||||
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
||||
grid_init = (
|
||||
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
|
||||
@@ -1171,8 +1172,8 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
|
||||
reference_points_list = []
|
||||
for level, (height, width) in enumerate(spatial_shapes):
|
||||
ref_y, ref_x = meshgrid(
|
||||
torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
|
||||
torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
|
||||
torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
|
||||
torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
# TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
|
||||
@@ -1540,15 +1541,15 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
for name, param in self.backbone.conv_encoder.model.named_parameters():
|
||||
param.requires_grad_(True)
|
||||
|
||||
def get_valid_ratio(self, mask):
|
||||
def get_valid_ratio(self, mask, dtype=torch.float32):
|
||||
"""Get the valid ratio of all feature maps."""
|
||||
|
||||
_, height, width = mask.shape
|
||||
valid_height = torch.sum(mask[:, :, 0], 1)
|
||||
valid_width = torch.sum(mask[:, 0, :], 1)
|
||||
valid_ratio_heigth = valid_height.float() / height
|
||||
valid_ratio_width = valid_width.float() / width
|
||||
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
|
||||
valid_ratio_height = valid_height.to(dtype) / height
|
||||
valid_ratio_width = valid_width.to(dtype) / width
|
||||
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
|
||||
return valid_ratio
|
||||
|
||||
def get_proposal_pos_embed(self, proposals):
|
||||
@@ -1721,7 +1722,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
||||
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
|
||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
|
||||
valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
|
||||
valid_ratios = valid_ratios.float()
|
||||
|
||||
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
|
||||
|
||||
@@ -1549,15 +1549,15 @@ class DetaModel(DetaPreTrainedModel):
|
||||
param.requires_grad_(True)
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio
|
||||
def get_valid_ratio(self, mask):
|
||||
def get_valid_ratio(self, mask, dtype=torch.float32):
|
||||
"""Get the valid ratio of all feature maps."""
|
||||
|
||||
_, height, width = mask.shape
|
||||
valid_height = torch.sum(mask[:, :, 0], 1)
|
||||
valid_width = torch.sum(mask[:, 0, :], 1)
|
||||
valid_ratio_heigth = valid_height.float() / height
|
||||
valid_ratio_width = valid_width.float() / width
|
||||
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
|
||||
valid_ratio_height = valid_height.to(dtype) / height
|
||||
valid_ratio_width = valid_width.to(dtype) / width
|
||||
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
|
||||
return valid_ratio
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_proposal_pos_embed
|
||||
|
||||
@@ -583,6 +583,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
def create_and_check_model_fp16_forward(self):
|
||||
model_class = DeformableDetrForObjectDetection
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
output = model(**inputs)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
Reference in New Issue
Block a user