DeformableDETR support bfloat16 (#29232)
* Update ms_deform_attn_cuda.cu * Update ms_deform_attn_cuda.cuh * Update modeling_deformable_detr.py * Update src/transformers/models/deformable_detr/modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update modeling_deformable_detr.py * python utils/check_copies.py --fix_and_overwrite * Fix dtype mismatch error * Update test_modeling_deformable_detr.py * Update test_modeling_deformable_detr.py * Update modeling_deformable_detr.py * Update modeling_deformable_detr.py * Support DeformableDETR with bfloat16 * Add test code * Use AT_DISPATCH_FLOATING_TYPES_AND2 Use AT_DISPATCH_FLOATING_TYPES_AND2 * Update tests/models/deformable_detr/test_modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/deformable_detr/test_modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix not found require_torch_bf16 function --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -64,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto columns = output_n.select(0, n);
|
auto columns = output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
spatial_shapes.data<int64_t>(),
|
spatial_shapes.data<int64_t>(),
|
||||||
@@ -134,7 +134,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto grad_output_g = grad_output_n.select(0, n);
|
auto grad_output_g = grad_output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
grad_output_g.data<scalar_t>(),
|
grad_output_g.data<scalar_t>(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto columns = output_n.select(0, n);
|
auto columns = output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
spatial_shapes.data<int64_t>(),
|
spatial_shapes.data<int64_t>(),
|
||||||
@@ -142,7 +142,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto grad_output_g = grad_output_n.select(0, n);
|
auto grad_output_g = grad_output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
grad_output_g.data<scalar_t>(),
|
grad_output_g.data<scalar_t>(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
|
|||||||
@@ -19,6 +19,14 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
const at::Tensor &attn_weight,
|
const at::Tensor &attn_weight,
|
||||||
const int im2col_step);
|
const int im2col_step);
|
||||||
|
|
||||||
|
at::Tensor ms_deform_attn_cuda_forward_bf16(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const int im2col_step);
|
||||||
|
|
||||||
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||||
const at::Tensor &value,
|
const at::Tensor &value,
|
||||||
const at::Tensor &spatial_shapes,
|
const at::Tensor &spatial_shapes,
|
||||||
@@ -27,3 +35,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
const at::Tensor &attn_weight,
|
const at::Tensor &attn_weight,
|
||||||
const at::Tensor &grad_output,
|
const at::Tensor &grad_output,
|
||||||
const int im2col_step);
|
const int im2col_step);
|
||||||
|
|
||||||
|
std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const at::Tensor &grad_output,
|
||||||
|
const int im2col_step);
|
||||||
|
|||||||
@@ -1758,7 +1758,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
|||||||
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
|
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
|
||||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||||
valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
|
valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
|
||||||
valid_ratios = valid_ratios.float()
|
|
||||||
|
|
||||||
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
|
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
|
||||||
# Also provide spatial_shapes, level_start_index and valid_ratios
|
# Also provide spatial_shapes, level_start_index and valid_ratios
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from transformers.testing_utils import (
|
|||||||
require_timm,
|
require_timm,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
|
require_torch_bf16,
|
||||||
require_vision,
|
require_vision,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
@@ -591,6 +592,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
output = model(**inputs)["last_hidden_state"]
|
output = model(**inputs)["last_hidden_state"]
|
||||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||||
|
|
||||||
|
@require_torch_bf16
|
||||||
|
def create_and_check_model_bf16_forward(self):
|
||||||
|
model_class = DeformableDetrForObjectDetection
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
model = model_class(config, torch_dtype=torch.bfloat16)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
output = model(**inputs)["last_hidden_state"]
|
||||||
|
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||||
|
|
||||||
|
|
||||||
TOLERANCE = 1e-4
|
TOLERANCE = 1e-4
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user