DeformableDETR support bfloat16 (#29232)
* Update ms_deform_attn_cuda.cu * Update ms_deform_attn_cuda.cuh * Update modeling_deformable_detr.py * Update src/transformers/models/deformable_detr/modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update modeling_deformable_detr.py * python utils/check_copies.py --fix_and_overwrite * Fix dtype missmatch error * Update test_modeling_deformable_detr.py * Update test_modeling_deformable_detr.py * Update modeling_deformable_detr.py * Update modeling_deformable_detr.py * Support DeformableDETR with bfloat16 * Add test code * Use AT_DISPATCH_FLOATING_TYPES_AND2 Use AT_DISPATCH_FLOATING_TYPES_AND2 * Update tests/models/deformable_detr/test_modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/deformable_detr/test_modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix not found require_torch_bf16 function --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -64,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||
{
|
||||
auto columns = output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||
spatial_shapes.data<int64_t>(),
|
||||
@@ -134,7 +134,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||
{
|
||||
auto grad_output_g = grad_output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||
grad_output_g.data<scalar_t>(),
|
||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||
|
||||
@@ -72,7 +72,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||
{
|
||||
auto columns = output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||
spatial_shapes.data<int64_t>(),
|
||||
@@ -142,7 +142,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||
{
|
||||
auto grad_output_g = grad_output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||
grad_output_g.data<scalar_t>(),
|
||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||
|
||||
@@ -19,6 +19,14 @@ at::Tensor ms_deform_attn_cuda_forward(
|
||||
const at::Tensor &attn_weight,
|
||||
const int im2col_step);
|
||||
|
||||
at::Tensor ms_deform_attn_cuda_forward_bf16(
|
||||
const at::Tensor &value,
|
||||
const at::Tensor &spatial_shapes,
|
||||
const at::Tensor &level_start_index,
|
||||
const at::Tensor &sampling_loc,
|
||||
const at::Tensor &attn_weight,
|
||||
const int im2col_step);
|
||||
|
||||
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||
const at::Tensor &value,
|
||||
const at::Tensor &spatial_shapes,
|
||||
@@ -27,3 +35,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||
const at::Tensor &attn_weight,
|
||||
const at::Tensor &grad_output,
|
||||
const int im2col_step);
|
||||
|
||||
std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
|
||||
const at::Tensor &value,
|
||||
const at::Tensor &spatial_shapes,
|
||||
const at::Tensor &level_start_index,
|
||||
const at::Tensor &sampling_loc,
|
||||
const at::Tensor &attn_weight,
|
||||
const at::Tensor &grad_output,
|
||||
const int im2col_step);
|
||||
|
||||
@@ -1758,7 +1758,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
|
||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||
valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
|
||||
valid_ratios = valid_ratios.float()
|
||||
|
||||
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
|
||||
# Also provide spatial_shapes, level_start_index and valid_ratios
|
||||
|
||||
@@ -26,6 +26,7 @@ from transformers.testing_utils import (
|
||||
require_timm,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_bf16,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -591,6 +592,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
output = model(**inputs)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
@require_torch_bf16
|
||||
def create_and_check_model_bf16_forward(self):
|
||||
model_class = DeformableDetrForObjectDetection
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
model = model_class(config, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
output = model(**inputs)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
Reference in New Issue
Block a user