DeformableDetrModel support fp16 (#29013)
* Update ms_deform_attn_cuda.cu * Update ms_deform_attn_cuda.cuh * Update modeling_deformable_detr.py * Update src/transformers/models/deformable_detr/modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update modeling_deformable_detr.py * python utils/check_copies.py --fix_and_overwrite * Fix dtype missmatch error * Update test_modeling_deformable_detr.py * Update test_modeling_deformable_detr.py * Update modeling_deformable_detr.py * Update modeling_deformable_detr.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -64,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto columns = output_n.select(0, n);
|
auto columns = output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
spatial_shapes.data<int64_t>(),
|
spatial_shapes.data<int64_t>(),
|
||||||
@@ -134,7 +134,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto grad_output_g = grad_output_n.select(0, n);
|
auto grad_output_g = grad_output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
grad_output_g.data<scalar_t>(),
|
grad_output_g.data<scalar_t>(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto columns = output_n.select(0, n);
|
auto columns = output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
spatial_shapes.data<int64_t>(),
|
spatial_shapes.data<int64_t>(),
|
||||||
@@ -142,7 +142,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto grad_output_g = grad_output_n.select(0, n);
|
auto grad_output_g = grad_output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
grad_output_g.data<scalar_t>(),
|
grad_output_g.data<scalar_t>(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
|
|||||||
@@ -617,7 +617,8 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
|
|||||||
|
|
||||||
def _reset_parameters(self):
|
def _reset_parameters(self):
|
||||||
nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
|
nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
|
||||||
thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads)
|
default_dtype = torch.get_default_dtype()
|
||||||
|
thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads)
|
||||||
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
||||||
grid_init = (
|
grid_init = (
|
||||||
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
|
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
|
||||||
@@ -1171,8 +1172,8 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
|
|||||||
reference_points_list = []
|
reference_points_list = []
|
||||||
for level, (height, width) in enumerate(spatial_shapes):
|
for level, (height, width) in enumerate(spatial_shapes):
|
||||||
ref_y, ref_x = meshgrid(
|
ref_y, ref_x = meshgrid(
|
||||||
torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
|
torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
|
||||||
torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
|
torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
|
||||||
indexing="ij",
|
indexing="ij",
|
||||||
)
|
)
|
||||||
# TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
|
# TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
|
||||||
@@ -1540,15 +1541,15 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
|||||||
for name, param in self.backbone.conv_encoder.model.named_parameters():
|
for name, param in self.backbone.conv_encoder.model.named_parameters():
|
||||||
param.requires_grad_(True)
|
param.requires_grad_(True)
|
||||||
|
|
||||||
def get_valid_ratio(self, mask):
|
def get_valid_ratio(self, mask, dtype=torch.float32):
|
||||||
"""Get the valid ratio of all feature maps."""
|
"""Get the valid ratio of all feature maps."""
|
||||||
|
|
||||||
_, height, width = mask.shape
|
_, height, width = mask.shape
|
||||||
valid_height = torch.sum(mask[:, :, 0], 1)
|
valid_height = torch.sum(mask[:, :, 0], 1)
|
||||||
valid_width = torch.sum(mask[:, 0, :], 1)
|
valid_width = torch.sum(mask[:, 0, :], 1)
|
||||||
valid_ratio_heigth = valid_height.float() / height
|
valid_ratio_height = valid_height.to(dtype) / height
|
||||||
valid_ratio_width = valid_width.float() / width
|
valid_ratio_width = valid_width.to(dtype) / width
|
||||||
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
|
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
|
||||||
return valid_ratio
|
return valid_ratio
|
||||||
|
|
||||||
def get_proposal_pos_embed(self, proposals):
|
def get_proposal_pos_embed(self, proposals):
|
||||||
@@ -1721,7 +1722,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
|||||||
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
||||||
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
|
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
|
||||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||||
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
|
valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
|
||||||
valid_ratios = valid_ratios.float()
|
valid_ratios = valid_ratios.float()
|
||||||
|
|
||||||
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
|
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
|
||||||
|
|||||||
@@ -1549,15 +1549,15 @@ class DetaModel(DetaPreTrainedModel):
|
|||||||
param.requires_grad_(True)
|
param.requires_grad_(True)
|
||||||
|
|
||||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio
|
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio
|
||||||
def get_valid_ratio(self, mask):
|
def get_valid_ratio(self, mask, dtype=torch.float32):
|
||||||
"""Get the valid ratio of all feature maps."""
|
"""Get the valid ratio of all feature maps."""
|
||||||
|
|
||||||
_, height, width = mask.shape
|
_, height, width = mask.shape
|
||||||
valid_height = torch.sum(mask[:, :, 0], 1)
|
valid_height = torch.sum(mask[:, :, 0], 1)
|
||||||
valid_width = torch.sum(mask[:, 0, :], 1)
|
valid_width = torch.sum(mask[:, 0, :], 1)
|
||||||
valid_ratio_heigth = valid_height.float() / height
|
valid_ratio_height = valid_height.to(dtype) / height
|
||||||
valid_ratio_width = valid_width.float() / width
|
valid_ratio_width = valid_width.to(dtype) / width
|
||||||
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
|
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
|
||||||
return valid_ratio
|
return valid_ratio
|
||||||
|
|
||||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_proposal_pos_embed
|
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_proposal_pos_embed
|
||||||
|
|||||||
@@ -583,6 +583,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
loss = model(**inputs).loss
|
loss = model(**inputs).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def create_and_check_model_fp16_forward(self):
|
||||||
|
model_class = DeformableDetrForObjectDetection
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.half()
|
||||||
|
model.eval()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
output = model(**inputs)["last_hidden_state"]
|
||||||
|
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||||
|
|
||||||
|
|
||||||
TOLERANCE = 1e-4
|
TOLERANCE = 1e-4
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user