From 59407bbeb31fff8340938768051c9daabd38d7a7 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Wed, 14 Sep 2022 11:45:21 +0200 Subject: [PATCH] Add Deformable DETR (#17281) * First draft * More improvements * Improve model, add custom CUDA code * Import torch before * Add script that imports custom layer * Add everything in new ops directory * Import custom layer in modeling file * Fix ARCHIVE_MAP typo * Creating the custom kernel on the fly. * Import custom layer in modeling file * More improvements * Fix CUDA loading * More improvements * Improve conversion script * Improve conversion script * Make it work until encoder_outputs * Make forward pass work * More improvements * Make logits match original implementation * Make implementation also support single_scale model * Add support for single_scale and dilation checkpoint * Add support for with_box_refine model * Support also two stage model * Improve tests * Fix more tests * Make more tests pass * Upload all models to the hub * Clean up some code * Improve decoder outputs * Rename intermediate hidden states and reference points * Improve model outputs * Move tests to dedicated folder * Improve model outputs * Fix retain_grad test * Improve docs * Clean up and make test_initialization pass * Improve variable names * Add copied from statements * Improve docs * Fix style * Improve docs * Improve docs, move tests to model folder * Fix rebase * Remove DetrForSegmentation from auto mapping * Apply suggestions from code review * Improve variable names and docstrings * Apply some more suggestions from code review * Apply suggestion from code review * better docs and variables names * hint to num_queries and two_stage confusion * remove asserts and code refactor * add exception if two_stage is True and with_box_refine is False * use f-strings * Improve docs and variable names * Fix code quality * Fix rebase * Add require_torch_gpu decorator * Add pip install ninja to CI jobs * Apply suggestion of @sgugger * Remove DeformableDetrForObjectDetection from auto mapping * Remove DeformableDetrModel from auto mapping * Add model to toctree * Add model back to mappings, skip model in pipeline tests * Apply @sgugger's suggestion * Fix imports in the init * Fix copies * Add CPU implementation * Comment out GPU function * Undo previous change * Apply more suggestions * Remove require_torch_gpu annotator * Fix quality * Add logger.info * Fix logger * Fix variable names * Fix initializaztion * Add missing initialization * Update checkpoint name * Add model to doc tests * Add CPU/GPU equivalence test * Add Deformable DETR to pipeline tests * Skip model for object detection pipeline Co-authored-by: Nicolas Patry Co-authored-by: Nouamane Tazi Co-authored-by: Sylvain Gugger --- README.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/_toctree.yml | 2 + docs/source/en/index.mdx | 2 + docs/source/en/model_doc/deformable_detr.mdx | 50 + setup.py | 1 + src/transformers/__init__.py | 24 +- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/deformable_detr/__init__.py | 61 + .../configuration_deformable_detr.py | 227 ++ .../convert_deformable_detr_to_pytorch.py | 237 ++ .../custom_kernel/cpu/ms_deform_attn_cpu.cpp | 40 + .../custom_kernel/cpu/ms_deform_attn_cpu.h | 32 + .../custom_kernel/cuda/ms_deform_attn_cuda.cu | 156 ++ .../cuda/ms_deform_attn_cuda.cuh | 1467 ++++++++++ .../custom_kernel/cuda/ms_deform_attn_cuda.h | 29 + .../cuda/ms_deform_im2col_cuda.cuh | 1327 +++++++++ .../custom_kernel/ms_deform_attn.h | 61 + .../deformable_detr/custom_kernel/vision.cpp | 16 + .../models/deformable_detr/load_custom.py | 51 + .../modeling_deformable_detr.py | 2465 +++++++++++++++++ src/transformers/models/detr/modeling_detr.py | 143 +- .../models/maskformer/modeling_maskformer.py | 61 +- .../models/yolos/modeling_yolos.py | 51 +- .../utils/dummy_timm_and_vision_objects.py | 24 + src/transformers/utils/dummy_timm_objects.py | 32 - tests/models/deformable_detr/__init__.py | 0 .../test_modeling_deformable_detr.py | 625 +++++ .../test_pipelines_object_detection.py | 6 + utils/check_repo.py | 2 + utils/documentation_tests.txt | 1 + 36 files changed, 7048 insertions(+), 156 deletions(-) create mode 100644 docs/source/en/model_doc/deformable_detr.mdx create mode 100644 src/transformers/models/deformable_detr/__init__.py create mode 100644 src/transformers/models/deformable_detr/configuration_deformable_detr.py create mode 100644 src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py create mode 100644 src/transformers/models/deformable_detr/custom_kernel/cpu/ms_deform_attn_cpu.cpp create mode 100644 src/transformers/models/deformable_detr/custom_kernel/cpu/ms_deform_attn_cpu.h create mode 100644 src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.cu create mode 100644 src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.cuh create mode 100644 src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.h create mode 100644 src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_im2col_cuda.cuh create mode 100644 src/transformers/models/deformable_detr/custom_kernel/ms_deform_attn.h create mode 100644 src/transformers/models/deformable_detr/custom_kernel/vision.cpp create mode 100644 src/transformers/models/deformable_detr/load_custom.py create mode 100755 src/transformers/models/deformable_detr/modeling_deformable_detr.py delete mode 100644 src/transformers/utils/dummy_timm_objects.py create mode 100644 tests/models/deformable_detr/__init__.py create mode 100644 tests/models/deformable_detr/test_modeling_deformable_detr.py diff --git a/README.md b/README.md index 6edaf4f012..570f12ac44 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch. +1. **[Deformable DETR](https://huggingface.co/docs/transformers/main/model_doc/deformable_detr)** (from SenseTime Research) released with the paper [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159) by Xizhou Zhu, Weijie Su, Lewei Lu, Bin Li, Xiaogang Wang, Jifeng Dai. 1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. 1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. diff --git a/README_ko.md b/README_ko.md index 18c3009758..c6016624a8 100644 --- a/README_ko.md +++ b/README_ko.md @@ -237,6 +237,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch. +1. **[Deformable DETR](https://huggingface.co/docs/transformers/main/model_doc/deformable_detr)** (from SenseTime Research) released with the paper [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159) by Xizhou Zhu, Weijie Su, Lewei Lu, Bin Li, Xiaogang Wang, Jifeng Dai. 1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. 1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. diff --git a/README_zh-hans.md b/README_zh-hans.md index 7771026565..f3c07bfb36 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -261,6 +261,7 @@ conda install -c huggingface transformers 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (来自 Microsoft) 伴随论文 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 由 Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 发布。 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (来自 Microsoft) 伴随论文 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 由 Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 发布。 1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (来自 Berkeley/Facebook/Google) 伴随论文 [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) 由 Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch 发布。 +1. **[Deformable DETR](https://huggingface.co/docs/transformers/main/model_doc/deformable_detr)** (来自 SenseTime Research) 伴随论文 [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159) 由 Xizhou Zhu, Weijie Su, Lewei Lu, Bin Li, Xiaogang Wang, Jifeng Dai 发布。 1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (来自 Facebook) 伴随论文 [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) 由 Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou 发布。 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (来自 Facebook) 伴随论文 [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) 由 Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko 发布。 1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (来自 Microsoft Research) 伴随论文 [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) 由 Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 196f474184..2ef861d059 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -273,6 +273,7 @@ conda install -c huggingface transformers 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch. +1. **[Deformable DETR](https://huggingface.co/docs/transformers/main/model_doc/deformable_detr)** (from SenseTime Research) released with the paper [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159) by Xizhou Zhu, Weijie Su, Lewei Lu, Bin Li, Xiaogang Wang, Jifeng Dai. 1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. 1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c4c5f2162a..b59fdfbc46 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -364,6 +364,8 @@ title: ConvNeXT - local: model_doc/cvt title: CvT + - local: model_doc/deformable_detr + title: Deformable DETR - local: model_doc/deit title: DeiT - local: model_doc/detr diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index fb2ff2d418..265fe39f25 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -77,6 +77,7 @@ The documentation is organized into five sections: 1. **[DeBERTa](model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[Decision Transformer](model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch. +1. **[Deformable DETR](model_doc/deformable_detr)** (from SenseTime Research) released with the paper [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159) by Xizhou Zhu, Weijie Su, Lewei Lu, Bin Li, Xiaogang Wang, Jifeng Dai. 1. **[DeiT](model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. 1. **[DETR](model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. 1. **[DialoGPT](model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. @@ -223,6 +224,7 @@ Flax), PyTorch, and/or TensorFlow. | DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ | | DeBERTa-v2 | ✅ | ✅ | ✅ | ✅ | ❌ | | Decision Transformer | ❌ | ❌ | ✅ | ❌ | ❌ | +| Deformable DETR | ❌ | ❌ | ✅ | ❌ | ❌ | | DeiT | ❌ | ❌ | ✅ | ✅ | ❌ | | DETR | ❌ | ❌ | ✅ | ❌ | ❌ | | DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/deformable_detr.mdx b/docs/source/en/model_doc/deformable_detr.mdx new file mode 100644 index 0000000000..7997b2f19d --- /dev/null +++ b/docs/source/en/model_doc/deformable_detr.mdx @@ -0,0 +1,50 @@ + + +# Deformable DETR + +## Overview + +The Deformable DETR model was proposed in [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159) by Xizhou Zhu, Weijie Su, Lewei Lu, Bin Li, Xiaogang Wang, Jifeng Dai. +Deformable DETR mitigates the slow convergence issues and limited feature spatial resolution of the original [DETR](detr) by leveraging a new deformable attention module which only attends to a small set of key sampling points around a reference. + +The abstract from the paper is the following: + +*DETR has been recently proposed to eliminate the need for many hand-designed components in object detection while demonstrating good performance. However, it suffers from slow convergence and limited feature spatial resolution, due to the limitation of Transformer attention modules in processing image feature maps. To mitigate these issues, we proposed Deformable DETR, whose attention modules only attend to a small set of key sampling points around a reference. Deformable DETR can achieve better performance than DETR (especially on small objects) with 10 times less training epochs. Extensive experiments on the COCO benchmark demonstrate the effectiveness of our approach.* + +Tips: + +- One can use the [`AutoFeatureExtractor`] API to prepare images (and optional targets) for the model. This will instantiate a [`DetrFeatureExtractor`] behind the scenes. +- Training Deformable DETR is equivalent to training the original [DETR](detr) model. Demo notebooks can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DETR). + + + + Deformable DETR architecture. Taken from the original paper. + +This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/fundamentalvision/Deformable-DETR). + +## DeformableDetrConfig + +[[autodoc]] DeformableDetrConfig + + +## DeformableDetrModel + +[[autodoc]] DeformableDetrModel + - forward + + +## DeformableDetrForObjectDetection + +[[autodoc]] DeformableDetrForObjectDetection + - forward \ No newline at end of file diff --git a/setup.py b/setup.py index 3145272ef6..84bd8f5d6e 100644 --- a/setup.py +++ b/setup.py @@ -411,6 +411,7 @@ setup( url="https://github.com/huggingface/transformers", package_dir={"": "src"}, packages=find_packages("src"), + package_data={"transformers": ["py.typed", "*.cu", "*.cpp", "*.cuh", "*.h"]}, zip_safe=False, extras_require=extras, entry_points={"console_scripts": ["transformers-cli=transformers.commands.transformers_cli:main"]}, diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ee23c79db6..2671b37d8e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -187,6 +187,7 @@ _import_structure = { "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"], "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"], "models.decision_transformer": ["DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "DecisionTransformerConfig"], + "models.deformable_detr": ["DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeformableDetrConfig"], "models.deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"], "models.detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"], "models.dialogpt": [], @@ -682,12 +683,20 @@ try: if not (is_timm_available() and is_vision_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_timm_objects + from .utils import dummy_timm_and_vision_objects - _import_structure["utils.dummy_timm_objects"] = [ - name for name in dir(dummy_timm_objects) if not name.startswith("_") + _import_structure["utils.dummy_timm_and_vision_objects"] = [ + name for name in dir(dummy_timm_and_vision_objects) if not name.startswith("_") ] else: + _import_structure["models.deformable_detr"].extend( + [ + "DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST", + "DeformableDetrForObjectDetection", + "DeformableDetrModel", + "DeformableDetrPreTrainedModel", + ] + ) _import_structure["models.detr"].extend( [ "DETR_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3072,6 +3081,7 @@ if TYPE_CHECKING: DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, DecisionTransformerConfig, ) + from .models.deformable_detr import DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DeformableDetrConfig from .models.deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig from .models.detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig from .models.distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, DistilBertTokenizer @@ -3502,8 +3512,14 @@ if TYPE_CHECKING: if not (is_timm_available() and is_vision_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils.dummy_timm_objects import * + from .utils.dummy_timm_and_vision_objects import * else: + from .models.deformable_detr import ( + DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST, + DeformableDetrForObjectDetection, + DeformableDetrModel, + DeformableDetrPreTrainedModel, + ) from .models.detr import ( DETR_PRETRAINED_MODEL_ARCHIVE_LIST, DetrForObjectDetection, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 0e2bad475a..6a206bb968 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -47,6 +47,7 @@ from . import ( deberta, deberta_v2, decision_transformer, + deformable_detr, deit, detr, dialogpt, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 785518fbcd..ae0e88bd4a 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -53,6 +53,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("deberta", "DebertaConfig"), ("deberta-v2", "DebertaV2Config"), ("decision_transformer", "DecisionTransformerConfig"), + ("deformable_detr", "DeformableDetrConfig"), ("deit", "DeiTConfig"), ("detr", "DetrConfig"), ("distilbert", "DistilBertConfig"), @@ -182,6 +183,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( ("data2vec-vision", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("deformable_detr", "DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -307,6 +309,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("deberta", "DeBERTa"), ("deberta-v2", "DeBERTa-v2"), ("decision_transformer", "Decision Transformer"), + ("deformable_detr", "Deformable DETR"), ("deit", "DeiT"), ("detr", "DETR"), ("dialogpt", "DialoGPT"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 625b79db06..015fd132ef 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -43,6 +43,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( ("cvt", "ConvNextFeatureExtractor"), ("data2vec-audio", "Wav2Vec2FeatureExtractor"), ("data2vec-vision", "BeitFeatureExtractor"), + ("deformable_detr", "DetrFeatureExtractor"), ("deit", "DeiTFeatureExtractor"), ("detr", "DetrFeatureExtractor"), ("detr", "DetrFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 9423692c34..9edfae0c89 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -53,6 +53,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("deberta-v2", "DebertaV2Model"), ("decision_transformer", "DecisionTransformerModel"), ("decision_transformer_gpt2", "DecisionTransformerGPT2Model"), + ("deformable_detr", "DeformableDetrModel"), ("deit", "DeiTModel"), ("detr", "DetrModel"), ("distilbert", "DistilBertModel"), @@ -451,6 +452,7 @@ MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( [ # Model for Object Detection mapping + ("deformable_detr", "DeformableDetrForObjectDetection"), ("detr", "DetrForObjectDetection"), ("yolos", "YolosForObjectDetection"), ] diff --git a/src/transformers/models/deformable_detr/__init__.py b/src/transformers/models/deformable_detr/__init__.py new file mode 100644 index 0000000000..f70d937c7f --- /dev/null +++ b/src/transformers/models/deformable_detr/__init__.py @@ -0,0 +1,61 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available + + +_import_structure = { + "configuration_deformable_detr": ["DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeformableDetrConfig"], +} + +try: + if not is_timm_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_deformable_detr"] = [ + "DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST", + "DeformableDetrForObjectDetection", + "DeformableDetrModel", + "DeformableDetrPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_deformable_detr import DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DeformableDetrConfig + + try: + if not is_timm_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_deformable_detr import ( + DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST, + DeformableDetrForObjectDetection, + DeformableDetrModel, + DeformableDetrPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/deformable_detr/configuration_deformable_detr.py b/src/transformers/models/deformable_detr/configuration_deformable_detr.py new file mode 100644 index 0000000000..bded523303 --- /dev/null +++ b/src/transformers/models/deformable_detr/configuration_deformable_detr.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2022 SenseTime and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Deformable DETR model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "SenseTime/deformable-detr": "https://huggingface.co/sensetime/deformable-detr/resolve/main/config.json", + # See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr +} + + +class DeformableDetrConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeformableDetrModel`]. It is used to instantiate + a Deformable DETR model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Deformable DETR + [SenseTime/deformable-detr](https://huggingface.co/SenseTime/deformable-detr) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_queries (`int`, *optional*, defaults to 300): + Number of object queries, i.e. detection slots. This is the maximal number of objects + [`DeformableDetrModel`] can detect in a single image. In case `two_stage` is set to `True`, we use + `two_stage_num_proposals` instead. + d_model (`int`, *optional*, defaults to 256): + Dimension of the layers. + encoder_layers (`int`, *optional*, defaults to 6): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + init_xavier_std (`float`, *optional*, defaults to 1): + The scaling factor used for the Xavier initialization gain in the HM Attention map module. + encoder_layerdrop: (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + decoder_layerdrop: (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + auxiliary_loss (`bool`, *optional*, defaults to `False`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + position_embedding_type (`str`, *optional*, defaults to `"sine"`): + Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`. + backbone (`str`, *optional*, defaults to `"resnet50"`): + Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a + list of all available models, see [this + page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model). + dilation (`bool`, *optional*, defaults to `False`): + Whether to replace stride with dilation in the last convolutional block (DC5). + class_cost (`float`, *optional*, defaults to 1): + Relative weight of the classification error in the Hungarian matching cost. + bbox_cost (`float`, *optional*, defaults to 5): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + mask_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the Focal loss in the panoptic segmentation loss. + dice_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the DICE/F-1 loss in the panoptic segmentation loss. + bbox_loss_coefficient (`float`, *optional*, defaults to 5): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.1): + Relative classification weight of the 'no-object' class in the object detection loss. + num_feature_levels (`int`, *optional*, defaults to 4): + The number of input feature levels. + encoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the encoder. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + two_stage (`bool`, *optional*, defaults to `False`): + Whether to apply a two-stage deformable DETR, where the region proposals are also generated by a variant of + Deformable DETR, which are further fed into the decoder for iterative bounding box refinement. + two_stage_num_proposals (`int`, *optional*, defaults to 300): + The number of region proposals to be generated, in case `two_stage` is set to `True`. + with_box_refine (`bool`, *optional*, defaults to `False`): + Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes + based on the predictions from the previous layer. + + Examples: + + ```python + >>> from transformers import DeformableDetrModel, DeformableDetrConfig + + >>> # Initializing a Deformable DETR SenseTime/deformable-detr style configuration + >>> configuration = DeformableDetrConfig() + + >>> # Initializing a model from the SenseTime/deformable-detr style configuration + >>> model = DeformableDetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "deformable_detr" + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + num_queries=300, + max_position_embeddings=1024, + encoder_layers=6, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + decoder_layers=6, + decoder_ffn_dim=1024, + decoder_attention_heads=8, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + is_encoder_decoder=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + init_xavier_std=1.0, + return_intermediate=True, + auxiliary_loss=False, + position_embedding_type="sine", + backbone="resnet50", + dilation=False, + num_feature_levels=4, + encoder_n_points=4, + decoder_n_points=4, + two_stage=False, + two_stage_num_proposals=300, + with_box_refine=False, + class_cost=1, + bbox_cost=5, + giou_cost=2, + mask_loss_coefficient=1, + dice_loss_coefficient=1, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + eos_coefficient=0.1, + **kwargs + ): + self.num_queries = num_queries + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.init_xavier_std = init_xavier_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.auxiliary_loss = auxiliary_loss + self.position_embedding_type = position_embedding_type + self.backbone = backbone + self.dilation = dilation + # deformable attributes + self.num_feature_levels = num_feature_levels + self.encoder_n_points = encoder_n_points + self.decoder_n_points = decoder_n_points + self.two_stage = two_stage + self.two_stage_num_proposals = two_stage_num_proposals + self.with_box_refine = with_box_refine + if two_stage is True and with_box_refine is False: + raise ValueError("If two_stage is True, with_box_refine must be True.") + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.mask_loss_coefficient = mask_loss_coefficient + self.dice_loss_coefficient = dice_loss_coefficient + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.eos_coefficient = eos_coefficient + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model diff --git a/src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py b/src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py new file mode 100644 index 0000000000..85ee723fb3 --- /dev/null +++ b/src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Deformable DETR checkpoints.""" + + +import argparse +import json +from pathlib import Path + +import torch +from PIL import Image + +import requests +from huggingface_hub import cached_download, hf_hub_url +from transformers import DeformableDetrConfig, DeformableDetrForObjectDetection, DetrFeatureExtractor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def rename_key(orig_key): + if "backbone.0.body" in orig_key: + orig_key = orig_key.replace("backbone.0.body", "backbone.conv_encoder.model") + if "transformer" in orig_key: + orig_key = orig_key.replace("transformer.", "") + if "norm1" in orig_key: + if "encoder" in orig_key: + orig_key = orig_key.replace("norm1", "self_attn_layer_norm") + else: + orig_key = orig_key.replace("norm1", "encoder_attn_layer_norm") + if "norm2" in orig_key: + if "encoder" in orig_key: + orig_key = orig_key.replace("norm2", "final_layer_norm") + else: + orig_key = orig_key.replace("norm2", "self_attn_layer_norm") + if "norm3" in orig_key: + orig_key = orig_key.replace("norm3", "final_layer_norm") + if "linear1" in orig_key: + orig_key = orig_key.replace("linear1", "fc1") + if "linear2" in orig_key: + orig_key = orig_key.replace("linear2", "fc2") + if "query_embed" in orig_key: + orig_key = orig_key.replace("query_embed", "query_position_embeddings") + if "cross_attn" in orig_key: + orig_key = orig_key.replace("cross_attn", "encoder_attn") + + return orig_key + + +def read_in_q_k_v(state_dict): + # transformer decoder self-attention layers + for i in range(6): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = state_dict.pop(f"decoder.layers.{i}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"decoder.layers.{i}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + + return im + + +@torch.no_grad() +def convert_deformable_detr_checkpoint( + checkpoint_path, + single_scale, + dilation, + with_box_refine, + two_stage, + pytorch_dump_folder_path, + push_to_hub, +): + """ + Copy/paste/tweak model's weights to our Deformable DETR structure. + """ + + # load default config + config = DeformableDetrConfig() + # set config attributes + if single_scale: + config.num_feature_levels = 1 + config.dilation = dilation + config.with_box_refine = with_box_refine + config.two_stage = two_stage + # set labels + config.num_labels = 91 + repo_id = "datasets/huggingface/label-files" + filename = "coco-detection-id2label.json" + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + # load feature extractor + feature_extractor = DetrFeatureExtractor(format="coco_detection") + + # prepare image + img = prepare_img() + encoding = feature_extractor(images=img, return_tensors="pt") + pixel_values = encoding["pixel_values"] + + logger.info("Converting model...") + + # load original state dict + state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] + # rename keys + for key in state_dict.copy().keys(): + val = state_dict.pop(key) + state_dict[rename_key(key)] = val + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict) + # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them + prefix = "model." + for key in state_dict.copy().keys(): + if not key.startswith("class_embed") and not key.startswith("bbox_embed"): + val = state_dict.pop(key) + state_dict[prefix + key] = val + # finally, create HuggingFace model and load state dict + model = DeformableDetrForObjectDetection(config) + model.load_state_dict(state_dict) + model.eval() + + device = "cuda" if torch.cuda.is_available() else "cpu" + model.to(device) + # verify our conversion + outputs = model(pixel_values.to(device)) + + expected_logits = torch.tensor( + [[-9.6645, -4.3449, -5.8705], [-9.7035, -3.8504, -5.0724], [-10.5634, -5.3379, -7.5116]] + ) + expected_boxes = torch.tensor([[0.8693, 0.2289, 0.2492], [0.3150, 0.5489, 0.5845], [0.5563, 0.7580, 0.8518]]) + + if single_scale: + expected_logits = torch.tensor( + [[-9.9051, -4.2541, -6.4852], [-9.6947, -4.0854, -6.8033], [-10.0665, -5.8470, -7.7003]] + ) + expected_boxes = torch.tensor([[0.7292, 0.4991, 0.5532], [0.7959, 0.2426, 0.4236], [0.7582, 0.3518, 0.4451]]) + + if single_scale and dilation: + expected_logits = torch.tensor( + [[-8.9652, -4.1074, -5.6635], [-9.0596, -4.9447, -6.6075], [-10.1178, -4.5275, -6.2671]] + ) + expected_boxes = torch.tensor([[0.7665, 0.4130, 0.4769], [0.8364, 0.1841, 0.3391], [0.6261, 0.3895, 0.7978]]) + + if with_box_refine: + expected_logits = torch.tensor( + [[-8.8895, -5.4187, -6.8153], [-8.4706, -6.1668, -7.6184], [-9.0042, -5.5359, -6.9141]] + ) + expected_boxes = torch.tensor([[0.7828, 0.2208, 0.4323], [0.0892, 0.5996, 0.1319], [0.5524, 0.6389, 0.8914]]) + + if with_box_refine and two_stage: + expected_logits = torch.tensor( + [[-6.7108, -4.3213, -6.3777], [-8.9014, -6.1799, -6.7240], [-6.9315, -4.4735, -6.2298]] + ) + expected_boxes = torch.tensor([[0.2583, 0.5499, 0.4683], [0.7652, 0.9068, 0.4882], [0.5490, 0.2763, 0.0564]]) + + print("Logits:", outputs.logits[0, :3, :3]) + + assert torch.allclose(outputs.logits[0, :3, :3], expected_logits.to(device), atol=1e-4) + assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes.to(device), atol=1e-4) + + print("Everything ok!") + + # Save model and feature extractor + logger.info(f"Saving PyTorch model and feature extractor to {pytorch_dump_folder_path}...") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + # Push to hub + if push_to_hub: + model_name = "deformable-detr" + model_name += "-single-scale" if single_scale else "" + model_name += "-dc5" if dilation else "" + model_name += "-with-box-refine" if with_box_refine else "" + model_name += "-two-stage" if two_stage else "" + print("Pushing model to hub...") + model.push_to_hub(repo_path_or_name=model_name, organization="nielsr", commit_message="Add model") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", + type=str, + default="/home/niels/checkpoints/deformable_detr/r50_deformable_detr-checkpoint.pth", + help="Path to Pytorch checkpoint (.pth file) you'd like to convert.", + ) + parser.add_argument("--single_scale", action="store_true", help="Whether to set config.num_features_levels = 1.") + parser.add_argument("--dilation", action="store_true", help="Whether to set config.dilation=True.") + parser.add_argument("--with_box_refine", action="store_true", help="Whether to set config.with_box_refine=True.") + parser.add_argument("--two_stage", action="store_true", help="Whether to set config.two_stage=True.") + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the folder to output PyTorch model.", + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + args = parser.parse_args() + convert_deformable_detr_checkpoint( + args.checkpoint_path, + args.single_scale, + args.dilation, + args.with_box_refine, + args.two_stage, + args.pytorch_dump_folder_path, + args.push_to_hub, + ) diff --git a/src/transformers/models/deformable_detr/custom_kernel/cpu/ms_deform_attn_cpu.cpp b/src/transformers/models/deformable_detr/custom_kernel/cpu/ms_deform_attn_cpu.cpp new file mode 100644 index 0000000000..388a73d22d --- /dev/null +++ b/src/transformers/models/deformable_detr/custom_kernel/cpu/ms_deform_attn_cpu.cpp @@ -0,0 +1,40 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} diff --git a/src/transformers/models/deformable_detr/custom_kernel/cpu/ms_deform_attn_cpu.h b/src/transformers/models/deformable_detr/custom_kernel/cpu/ms_deform_attn_cpu.h new file mode 100644 index 0000000000..7eac8c8bcd --- /dev/null +++ b/src/transformers/models/deformable_detr/custom_kernel/cpu/ms_deform_attn_cpu.h @@ -0,0 +1,32 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + diff --git a/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.cu b/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.cu new file mode 100644 index 0000000000..8ea1d7fabe --- /dev/null +++ b/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.cu @@ -0,0 +1,156 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "cuda/ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + +#pragma once +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} diff --git a/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.cuh b/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.cuh new file mode 100644 index 0000000000..34f8ae9cb7 --- /dev/null +++ b/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.cuh @@ -0,0 +1,1467 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} diff --git a/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.h b/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.h new file mode 100644 index 0000000000..fbcf4543e6 --- /dev/null +++ b/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_attn_cuda.h @@ -0,0 +1,29 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); diff --git a/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_im2col_cuda.cuh b/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000000..c0db0c88c9 --- /dev/null +++ b/src/transformers/models/deformable_detr/custom_kernel/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} diff --git a/src/transformers/models/deformable_detr/custom_kernel/ms_deform_attn.h b/src/transformers/models/deformable_detr/custom_kernel/ms_deform_attn.h new file mode 100644 index 0000000000..119b1fa317 --- /dev/null +++ b/src/transformers/models/deformable_detr/custom_kernel/ms_deform_attn.h @@ -0,0 +1,61 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/src/transformers/models/deformable_detr/custom_kernel/vision.cpp b/src/transformers/models/deformable_detr/custom_kernel/vision.cpp new file mode 100644 index 0000000000..6ce3875568 --- /dev/null +++ b/src/transformers/models/deformable_detr/custom_kernel/vision.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} \ No newline at end of file diff --git a/src/transformers/models/deformable_detr/load_custom.py b/src/transformers/models/deformable_detr/load_custom.py new file mode 100644 index 0000000000..d2a8bc0cb2 --- /dev/null +++ b/src/transformers/models/deformable_detr/load_custom.py @@ -0,0 +1,51 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Loading of Deformable DETR's CUDA kernels""" + +import os + + +def load_cuda_kernels(): + from torch.utils.cpp_extension import load + + root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_kernel") + src_files = [ + os.path.join(root, filename) + for filename in [ + "vision.cpp", + os.path.join("cpu", "ms_deform_attn_cpu.cpp"), + os.path.join("cuda", "ms_deform_attn_cuda.cu"), + ] + ] + + load( + "MultiScaleDeformableAttention", + src_files, + # verbose=True, + with_cuda=True, + extra_include_paths=[root], + # build_directory=os.path.dirname(os.path.realpath(__file__)), + extra_cflags=["-DWITH_CUDA=1"], + extra_cuda_cflags=[ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ], + ) + + import MultiScaleDeformableAttention as MSDA + + return MSDA diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py new file mode 100755 index 0000000000..acd4d40124 --- /dev/null +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -0,0 +1,2465 @@ +# coding=utf-8 +# Copyright 2022 SenseTime and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Deformable DETR model.""" + + +import copy +import math +import warnings +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from ...activations import ACT2FN +from ...file_utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_scipy_available, + is_timm_available, + is_torch_cuda_available, + is_vision_available, + replace_return_docstrings, + requires_backends, +) +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_deformable_detr import DeformableDetrConfig +from .load_custom import load_cuda_kernels + + +logger = logging.get_logger(__name__) + +# Move this to not compile only when importing, this needs to happen later, like in __init__. +if is_torch_cuda_available(): + logger.info("Loading custom CUDA kernels...") + MultiScaleDeformableAttention = load_cuda_kernels() +else: + MultiScaleDeformableAttention = None + + +class MultiScaleDeformableAttentionFunction(Function): + @staticmethod + def forward( + context, + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step, + ): + context.im2col_step = im2col_step + output = MultiScaleDeformableAttention.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + context.im2col_step, + ) + context.save_for_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights + ) + return output + + @staticmethod + @once_differentiable + def backward(context, grad_output): + ( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) = context.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + context.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +if is_scipy_available(): + from scipy.optimize import linear_sum_assignment + +if is_vision_available(): + from transformers.models.detr.feature_extraction_detr import center_to_corners_format + +if is_timm_available(): + from timm import create_model + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeformableDetrConfig" +_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr" + +DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "sensetime/deformable-detr", + # See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr +] + + +@dataclass +class DeformableDetrDecoderOutput(ModelOutput): + """ + Base class for outputs of the DeformableDetrDecoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer) + - a stacked tensor of intermediate reference points. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + intermediate_hidden_states: torch.FloatTensor = None + intermediate_reference_points: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class DeformableDetrModelOutput(ModelOutput): + """ + Base class for outputs of the Deformable DETR encoder-decoder model. + + Args: + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries, + num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + """ + + init_reference_points: torch.FloatTensor = None + last_hidden_state: torch.FloatTensor = None + intermediate_hidden_states: torch.FloatTensor = None + intermediate_reference_points: torch.FloatTensor = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + enc_outputs_class: Optional[torch.FloatTensor] = None + enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + + +@dataclass +class DeformableDetrObjectDetectionOutput(ModelOutput): + """ + Output type of [`DeformableDetrForObjectDetection`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~AutoFeatureExtractor.post_process`] to retrieve the unnormalized bounding + boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries, + num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_heads, 4, + 4)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average + in the self-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[Dict] = None + logits: torch.FloatTensor = None + pred_boxes: torch.FloatTensor = None + auxiliary_outputs: Optional[List[Dict]] = None + init_reference_points: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[torch.FloatTensor] = None + intermediate_reference_points: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + enc_outputs_class: Optional = None + enc_outputs_coord_logits: Optional = None + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->DeformableDetr +class DeformableDetrFrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DeformableDetr +def replace_batch_norm(m, name=""): + for attr_str in dir(m): + target_attr = getattr(m, attr_str) + if isinstance(target_attr, nn.BatchNorm2d): + frozen = DeformableDetrFrozenBatchNorm2d(target_attr.num_features) + bn = getattr(m, attr_str) + frozen.weight.data.copy_(bn.weight) + frozen.bias.data.copy_(bn.bias) + frozen.running_mean.data.copy_(bn.running_mean) + frozen.running_var.data.copy_(bn.running_var) + setattr(m, attr_str, frozen) + for n, ch in m.named_children(): + replace_batch_norm(ch, n) + + +class DeformableDetrTimmConvEncoder(nn.Module): + """ + Convolutional encoder (backbone) from the timm library. + + nn.BatchNorm2d layers are replaced by DeformableDetrFrozenBatchNorm2d as defined above. + """ + + def __init__(self, config): + super().__init__() + + kwargs = {} + if config.dilation: + kwargs["output_stride"] = 16 + + requires_backends(self, ["timm"]) + + out_indices = (2, 3, 4) if config.num_feature_levels > 1 else (4,) + backbone = create_model( + config.backbone, pretrained=True, features_only=True, out_indices=out_indices, **kwargs + ) + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.feature_info.channels() + self.strides = self.model.feature_info.reduction() + + if "resnet" in config.backbone: + for name, parameter in self.model.named_parameters(): + if "layer2" not in name and "layer3" not in name and "layer4" not in name: + parameter.requires_grad_(False) + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + """ + Outputs feature maps of latter stages C_3 through C_5 in ResNet if `config.num_feature_levels > 1`, otherwise + outputs feature maps of C_5. + """ + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values) + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DeformableDetr +class DeformableDetrConvModel(nn.Module): + """ + This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder. + """ + + def __init__(self, conv_encoder, position_embedding): + super().__init__() + self.conv_encoder = conv_encoder + self.position_embedding = position_embedding + + def forward(self, pixel_values, pixel_mask): + # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples + out = self.conv_encoder(pixel_values, pixel_mask) + pos = [] + for feature_map, mask in out: + # position encoding + pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype)) + + return out, pos + + +# Copied from transformers.models.detr.modeling_detr._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): + """ + Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. + """ + batch_size, source_len = mask.size() + target_len = target_len if target_len is not None else source_len + + expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +class DeformableDetrSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.embedding_dim = embedding_dim + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, pixel_values, pixel_mask): + if pixel_mask is None: + raise ValueError("No pixel mask provided") + y_embed = pixel_mask.cumsum(1, dtype=torch.float32) + x_embed = pixel_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +# Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding +class DeformableDetrLearnedPositionEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, embedding_dim=256): + super().__init__() + self.row_embeddings = nn.Embedding(50, embedding_dim) + self.column_embeddings = nn.Embedding(50, embedding_dim) + + def forward(self, pixel_values, pixel_mask=None): + height, width = pixel_values.shape[-2:] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + return pos + + +# Copied from transformers.models.detr.modeling_detr.build_position_encoding with Detr->DeformableDetr +def build_position_encoding(config): + n_steps = config.d_model // 2 + if config.position_embedding_type == "sine": + # TODO find a better way of exposing other arguments + position_embedding = DeformableDetrSinePositionEmbedding(n_steps, normalize=True) + elif config.position_embedding_type == "learned": + position_embedding = DeformableDetrLearnedPositionEmbedding(n_steps) + else: + raise ValueError(f"Not supported {config.position_embedding_type}") + + return position_embedding + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) + return output.transpose(1, 2).contiguous() + + +class DeformableDetrMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. + """ + + def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}" + ) + dim_per_head = embed_dim // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 64 + + self.d_model = embed_dim + self.n_levels = n_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + + self._reset_parameters() + + def _reset_parameters(self): + nn.init.constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(self.attention_weights.weight.data, 0.0) + nn.init.constant_(self.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(self.value_proj.weight.data) + nn.init.constant_(self.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(self.output_proj.weight.data) + nn.init.constant_(self.output_proj.bias.data, 0.0) + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # we invert the attention_mask + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + try: + # GPU + output = MultiScaleDeformableAttentionFunction.apply( + value, + spatial_shapes, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + except Exception: + # CPU + output = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights) + output = self.output_proj(output) + + return output, attention_weights + + +class DeformableDetrMultiheadAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, target_len, embed_dim = hidden_states.size() + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # get queries, keys and values + query_states = self.q_proj(hidden_states) * self.scaling + key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class DeformableDetrEncoderLayer(nn.Module): + def __init__(self, config: DeformableDetrConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = DeformableDetrMultiscaleDeformableAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + n_levels=config.num_feature_levels, + n_points=config.encoder_n_points, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Input to the layer. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Attention mask. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings, to be added to `hidden_states`. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes of the backbone feature maps. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps. + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class DeformableDetrDecoderLayer(nn.Module): + def __init__(self, config: DeformableDetrConfig): + super().__init__() + self.embed_dim = config.d_model + + # self-attention + self.self_attn = DeformableDetrMultiheadAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + # cross-attention + self.encoder_attn = DeformableDetrMultiscaleDeformableAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + n_levels=config.num_feature_levels, + n_points=config.decoder_n_points, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + # feedforward neural networks + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(seq_len, batch, embed_dim)`. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings that are added to the queries and keys in the self-attention layer. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + second_residual = hidden_states + + # Cross-Attention + cross_attn_weights = None + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = second_residual + hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.detr.modeling_detr.DetrClassificationHead +class DeformableDetrClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, input_dim: int, inner_dim: int, num_classes: int, pooler_dropout: float): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class DeformableDetrPreTrainedModel(PreTrainedModel): + config_class = DeformableDetrConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + + def _init_weights(self, module): + std = self.config.init_std + + if isinstance(module, DeformableDetrLearnedPositionEmbedding): + nn.init.uniform_(module.row_embeddings.weight) + nn.init.uniform_(module.column_embeddings.weight) + elif isinstance(module, DeformableDetrMultiscaleDeformableAttention): + module._reset_parameters() + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if hasattr(module, "reference_points") and not self.config.two_stage: + nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) + nn.init.constant_(module.reference_points.bias.data, 0.0) + if hasattr(module, "level_embed"): + nn.init.normal_(module.level_embed) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DeformableDetrDecoder): + module.gradient_checkpointing = value + + +DEFORMABLE_DETR_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeformableDetrConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DEFORMABLE_DETR_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. + + Pixel values can be obtained using [`AutoFeatureExtractor`]. See [`AutoFeatureExtractor.__call__`] for + details. + + pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +class DeformableDetrEncoder(DeformableDetrPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a + [`DeformableDetrEncoderLayer`]. + + The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers. + + Args: + config: DeformableDetrConfig + """ + + def __init__(self, config: DeformableDetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)]) + + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + """ + Get reference points for each feature map. Used in decoder. + + Args: + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Valid ratios of each feature map. + device (`torch.device`): + Device on which to create the tensors. + Returns: + `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)` + """ + reference_points_list = [] + for level, (height, width) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device), + torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device), + ) + # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36 + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): + Starting index of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class DeformableDetrDecoder(DeformableDetrPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DeformableDetrDecoderLayer`]. + + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + + Some tweaks for Deformable DETR: + + - `position_embeddings`, `reference_points`, `spatial_shapes` and `valid_ratios` are added to the forward pass. + - it also returns a stack of intermediate outputs and reference points from all decoding layers. + + Args: + config: DeformableDetrConfig + """ + + def __init__(self, config: DeformableDetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layers = nn.ModuleList([DeformableDetrDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.gradient_checkpointing = False + + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + The query embeddings that are passed into the decoder. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each self-attention layer. + reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*): + Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area. + spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of the feature maps. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*): + Indexes for the start of each feature level. In range `[0, sequence_length]`. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*): + Ratio of valid area in each feature level. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + intermediate = () + intermediate_reference_points = () + + for idx, decoder_layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + ) + else: + if reference_points.shape[-1] != 2: + raise ValueError("Reference points' last dimension must be of size 2") + reference_points_input = reference_points[:, :, None] * valid_ratios[:, None] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + encoder_hidden_states, + encoder_attention_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[idx](hidden_states) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + if reference_points.shape[-1] != 2: + raise ValueError( + f"Reference points' last dimension must be of size 2, but is {reference_points.shape[-1]}" + ) + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + intermediate += (hidden_states,) + intermediate_reference_points += (reference_points,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + intermediate = torch.stack(intermediate) + intermediate_reference_points = torch.stack(intermediate_reference_points) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + intermediate, + intermediate_reference_points, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return DeformableDetrDecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_reference_points=intermediate_reference_points, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The bare Deformable DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw + hidden-states without any specific head on top. + """, + DEFORMABLE_DETR_START_DOCSTRING, +) +class DeformableDetrModel(DeformableDetrPreTrainedModel): + def __init__(self, config: DeformableDetrConfig): + super().__init__(config) + + # Create backbone + positional encoding + backbone = DeformableDetrTimmConvEncoder(config) + position_embeddings = build_position_encoding(config) + self.backbone = DeformableDetrConvModel(backbone, position_embeddings) + + # Create input projection layers + if config.num_feature_levels > 1: + num_backbone_outs = len(backbone.strides) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.intermediate_channel_sizes[_] + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1), + nn.GroupNorm(32, config.d_model), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, config.d_model), + ) + ) + in_channels = config.d_model + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1), + nn.GroupNorm(32, config.d_model), + ) + ] + ) + + if not config.two_stage: + self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model * 2) + + self.encoder = DeformableDetrEncoder(config) + self.decoder = DeformableDetrDecoder(config) + + self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model)) + + if config.two_stage: + self.enc_output = nn.Linear(config.d_model, config.d_model) + self.enc_output_norm = nn.LayerNorm(config.d_model) + self.pos_trans = nn.Linear(config.d_model * 2, config.d_model * 2) + self.pos_trans_norm = nn.LayerNorm(config.d_model * 2) + else: + self.reference_points = nn.Linear(config.d_model, 2) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(True) + + def get_valid_ratio(self, mask): + """Get the valid ratio of all feature maps.""" + + _, height, width = mask.shape + valid_height = torch.sum(~mask[:, :, 0], 1) + valid_width = torch.sum(~mask[:, 0, :], 1) + valid_ratio_heigth = valid_height.float() / height + valid_ratio_width = valid_width.float() / width + valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1) + return valid_ratio + + def get_proposal_pos_embed(self, proposals): + """Get the position embedding of the proposals.""" + + num_pos_feats = 128 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # batch_size, num_queries, 4 + proposals = proposals.sigmoid() * scale + # batch_size, num_queries, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes): + """Generate the encoder output proposals from encoded enc_output. + + Args: + enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder. + padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`. + spatial_shapes (Tensor[num_feature_levels, 2]): Spatial shapes of the feature maps. + + Returns: + `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction. + - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to + directly predict a bounding box. (without the need of a decoder) + - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse + sigmoid. + """ + batch_size = enc_output.shape[0] + proposals = [] + _cur = 0 + for level, (height, width) in enumerate(spatial_shapes): + mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1) + valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device), + torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale + width_heigth = torch.ones_like(grid) * 0.05 * (2.0**level) + proposal = torch.cat((grid, width_heigth), -1).view(batch_size, -1, 4) + proposals.append(proposal) + _cur += height * width + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) # inverse sigmoid + output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf")) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + # assign each pixel as an object query + object_query = enc_output + object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0)) + object_query = object_query.masked_fill(~output_proposals_valid, float(0)) + object_query = self.enc_output_norm(self.enc_output(object_query)) + return object_query, output_proposals + + @add_start_docstrings_to_model_forward(DEFORMABLE_DETR_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DeformableDetrModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values, + pixel_mask=None, + decoder_attention_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoFeatureExtractor, DeformableDetrModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("SenseTime/deformable-detr") + >>> model = DeformableDetrModel.from_pretrained("SenseTime/deformable-detr") + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device) + + # Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper) + # First, sent pixel_values + pixel_mask through Backbone to obtain the features + # which is a list of tuples + features, position_embeddings_list = self.backbone(pixel_values, pixel_mask) + + # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) + sources = [] + masks = [] + for level, (source, mask) in enumerate(features): + sources.append(self.input_proj[level](source)) + masks.append(mask) + if mask is None: + raise ValueError("No attention mask was provided") + + # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage + if self.config.num_feature_levels > len(sources): + _len_sources = len(sources) + for level in range(_len_sources, self.config.num_feature_levels): + if level == _len_sources: + source = self.input_proj[level](features[-1][0]) + else: + source = self.input_proj[level](sources[-1]) + mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone.position_embedding(source, mask).to(source.dtype) + sources.append(source) + masks.append(mask) + position_embeddings_list.append(pos_l) + + # Create queries + query_embeds = None + if not self.config.two_stage: + query_embeds = self.query_position_embeddings.weight + + # Prepare encoder inputs (by flattening) + source_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)): + batch_size, num_channels, height, width = source.shape + spatial_shape = (height, width) + spatial_shapes.append(spatial_shape) + source = source.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + source_flatten.append(source) + mask_flatten.append(mask) + source_flatten = torch.cat(source_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + # revert valid_ratios + valid_ratios = ~valid_ratios.bool() + valid_ratios = valid_ratios.float() + + # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder + # Also provide spatial_shapes, level_start_index and valid_ratios + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs_embeds=source_flatten, + attention_mask=mask_flatten, + position_embeddings=lvl_pos_embed_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Fifth, prepare decoder inputs + batch_size, _, num_channels = encoder_outputs[0].shape + enc_outputs_class = None + enc_outputs_coord_logits = None + if self.config.two_stage: + object_query_embedding, output_proposals = self.gen_encoder_output_proposals( + encoder_outputs[0], ~mask_flatten, spatial_shapes + ) + + # hack implementation for two-stage Deformable DETR + # apply a detection head to each pixel (A.4 in paper) + # linear projection for bounding box binary classification (i.e. foreground and background) + enc_outputs_class = self.decoder.class_embed[-1](object_query_embedding) + # 3-layer FFN to predict bounding boxes coordinates (bbox regression branch) + delta_bbox = self.decoder.bbox_embed[-1](object_query_embedding) + enc_outputs_coord_logits = delta_bbox + output_proposals + + # only keep top scoring `config.two_stage_num_proposals` proposals + topk = self.config.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_logits = torch.gather( + enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) + ) + + topk_coords_logits = topk_coords_logits.detach() + reference_points = topk_coords_logits.sigmoid() + init_reference_points = reference_points + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_logits))) + query_embed, target = torch.split(pos_trans_out, num_channels, dim=2) + else: + query_embed, target = torch.split(query_embeds, num_channels, dim=1) + query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1) + target = target.unsqueeze(0).expand(batch_size, -1, -1) + reference_points = self.reference_points(query_embed).sigmoid() + init_reference_points = reference_points + + decoder_outputs = self.decoder( + inputs_embeds=target, + position_embeddings=query_embed, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + enc_outputs = tuple(value for value in [enc_outputs_class, enc_outputs_coord_logits] if value is not None) + tuple_outputs = (init_reference_points,) + decoder_outputs + encoder_outputs + enc_outputs + + return tuple_outputs + + return DeformableDetrModelOutput( + init_reference_points=init_reference_points, + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + ) + + +@add_start_docstrings( + """ + Deformable DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on + top, for tasks such as COCO detection. + """, + DEFORMABLE_DETR_START_DOCSTRING, +) +class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel): + def __init__(self, config: DeformableDetrConfig): + super().__init__(config) + + # Deformable DETR encoder-decoder model + self.model = DeformableDetrModel(config) + + # Detection heads on top + self.class_embed = nn.Linear(config.d_model, config.num_labels) + self.bbox_embed = DeformableDetrMLPPredictionHead( + input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 + ) + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(config.num_labels) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = (config.decoder_layers + 1) if config.two_stage else config.decoder_layers + if config.with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + # hack implementation for iterative bounding box refinement + self.model.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) + self.model.decoder.bbox_embed = None + if config.two_stage: + # hack implementation for two-stage + self.model.decoder.class_embed = self.class_embed + for box_embed in self.bbox_embed: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + + # Initialize weights and apply final processing + self.post_init() + + # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + @add_start_docstrings_to_model_forward(DEFORMABLE_DETR_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DeformableDetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values, + pixel_mask=None, + decoder_attention_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoFeatureExtractor, DeformableDetrForObjectDetection + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("SenseTime/deformable-detr") + >>> model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr") + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> # convert outputs (bounding boxes and class logits) to COCO API + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = feature_extractor.post_process(outputs, target_sizes=target_sizes)[0] + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... # let's only keep detections with score > 0.7 + ... if score > 0.7: + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... ) + Detected cat with confidence 0.856 at location [342.19, 24.3, 640.02, 372.25] + Detected remote with confidence 0.739 at location [40.79, 72.78, 176.76, 117.25] + Detected cat with confidence 0.859 at location [16.5, 52.84, 318.25, 470.78] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # First, sent images through DETR base model to obtain encoder + decoder outputs + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2] + init_reference = outputs.init_reference_points if return_dict else outputs[0] + inter_references = outputs.intermediate_reference_points if return_dict else outputs[3] + + # class logits + predicted bounding boxes + outputs_classes = [] + outputs_coords = [] + + for level in range(hidden_states.shape[0]): + if level == 0: + reference = init_reference + else: + reference = inter_references[level - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[level](hidden_states[level]) + delta_bbox = self.bbox_embed[level](hidden_states[level]) + if reference.shape[-1] == 4: + outputs_coord_logits = delta_bbox + reference + elif reference.shape[-1] == 2: + delta_bbox[..., :2] += reference + outputs_coord_logits = delta_bbox + else: + raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}") + outputs_coord = outputs_coord_logits.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + + logits = outputs_class[-1] + pred_boxes = outputs_coord[-1] + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + # First: create the matcher + matcher = DeformableDetrHungarianMatcher( + class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost + ) + # Second: create the criterion + losses = ["labels", "boxes", "cardinality"] + criterion = DeformableDetrLoss( + matcher=matcher, + num_classes=self.config.num_labels, + eos_coef=self.config.eos_coefficient, + losses=losses, + ) + criterion.to(self.device) + # Third: compute the losses, based on outputs and labels + outputs_loss = {} + outputs_loss["logits"] = logits + outputs_loss["pred_boxes"] = pred_boxes + if self.config.auxiliary_loss: + intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4] + outputs_class = self.class_embed(intermediate) + outputs_coord = self.bbox_embed(intermediate).sigmoid() + auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord) + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + if self.config.two_stage: + enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid() + outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord} + + loss_dict = criterion(outputs_loss, labels) + # Fourth: compute total loss, as a weighted sum of the various losses + weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient} + weight_dict["loss_giou"] = self.config.giou_loss_coefficient + if self.config.auxiliary_loss: + aux_weight_dict = {} + for i in range(self.config.decoder_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + auxiliary_outputs + outputs + else: + output = (logits, pred_boxes) + outputs + tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output + + return tuple_outputs + + dict_outputs = DeformableDetrObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + intermediate_hidden_states=outputs.intermediate_hidden_states, + init_reference_points=outputs.init_reference_points, + intermediate_reference_points=outputs.intermediate_reference_points, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + ) + + return dict_outputs + + +# Copied from transformers.models.detr.modeling_detr.dice_loss +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs (0 for the negative class and 1 for the positive + class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + + Args: + inputs (`torch.FloatTensor` of arbitrary shape): + The predictions for each example. + targets (`torch.FloatTensor` with the same shape as `inputs`) + A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class + and 1 for the positive class). + alpha (`float`, *optional*, defaults to `0.25`): + Optional weighting factor in the range (0,1) to balance positive vs. negative examples. + gamma (`int`, *optional*, defaults to `2`): + Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. + + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + # add modulating factor + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class DeformableDetrLoss(nn.Module): + """ + This class computes the losses for DeformableDetrForObjectDetection. The process happens in two steps: 1) we + compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of + matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, num_classes, eos_coef, losses, focal_alpha=0.25): + """ + Create the criterion. + + A note on the num_classes parameter (copied from original repo in detr.py): "the naming of the `num_classes` + parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id + 1`, where max_obj_id + is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90, so we pass + `num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you should pass + `num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion + https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223" + + Parameters: + matcher: module able to compute a matching between targets and proposals. + num_classes: number of object categories, omitting the special no-object category. + eos_coef: relative classification weight applied to the no-object category. + losses: list of all the losses to be applied. See get_loss for list of available losses. + focal_alpha: alpha in Focal Loss. + """ + super().__init__() + + self.matcher = matcher + self.num_classes = num_classes + self.losses = losses + self.focal_alpha = focal_alpha + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + if "logits" not in outputs: + raise ValueError("No logits were found in the outputs") + + source_logits = outputs["logits"] + + idx = self._get_source_permutation_idx(indices) + target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full( + source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device + ) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros( + [source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1], + dtype=source_logits.dtype, + layout=source_logits.layout, + device=source_logits.device, + ) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:, :, :-1] + loss_ce = ( + sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) + * source_logits.shape[1] + ) + losses = {"loss_ce": loss_ce} + + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ + Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. + + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. + """ + logits = outputs["logits"] + device = logits.device + target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1) + card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) + losses = {"cardinality_error": card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """ + Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. + + Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes + are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + if "pred_boxes" not in outputs: + raise ValueError("No predicted boxes found in outputs") + + idx = self._get_source_permutation_idx(indices) + source_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none") + + losses = {} + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag( + generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)) + ) + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses + + def _get_source_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)]) + source_idx = torch.cat([source for (source, _) in indices]) + return batch_idx, source_idx + + def _get_target_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)]) + target_idx = torch.cat([target for (_, target) in indices]) + return batch_idx, target_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes): + loss_map = { + "labels": self.loss_labels, + "cardinality": self.loss_cardinality, + "boxes": self.loss_boxes, + } + if loss not in loss_map: + raise ValueError(f"Loss {loss} not supported") + + return loss_map[loss](outputs, targets, indices, num_boxes) + + def forward(self, outputs, targets): + """ + This performs the loss computation. + + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["class_labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + # (Niels): comment out function below, distributed training to be added + # if is_dist_avail_and_initialized(): + # torch.distributed.all_reduce(num_boxes) + # (Niels) in original implementation, num_boxes is divided by get_world_size() + num_boxes = torch.clamp(num_boxes, min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if "auxiliary_outputs" in outputs: + for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): + indices = self.matcher(auxiliary_outputs, targets) + for loss in self.losses: + l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes) + l_dict = {k + f"_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + if "enc_outputs" in outputs: + enc_outputs = outputs["enc_outputs"] + bin_targets = copy.deepcopy(targets) + for bt in bin_targets: + bt["labels"] = torch.zeros_like(bt["labels"]) + indices = self.matcher(enc_outputs, bin_targets) + for loss in self.losses: + kwargs = {} + if loss == "labels": + # Logging is enabled only for the last layer + kwargs["log"] = False + l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs) + l_dict = {k + "_enc": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead +class DeformableDetrMLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher +class DeformableDetrHungarianMatcher(nn.Module): + """ + This class computes an assignment between the targets and the predictions of the network. + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more + predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are + un-matched (and thus treated as non-objects). + + Args: + class_cost: + The relative weight of the classification error in the matching cost. + bbox_cost: + The relative weight of the L1 error of the bounding box coordinates in the matching cost. + giou_cost: + The relative weight of the giou loss of the bounding box in the matching cost. + """ + + def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1): + super().__init__() + requires_backends(self, ["scipy"]) + + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + if class_cost == 0 and bbox_cost == 0 and giou_cost == 0: + raise ValueError("All costs of the Matcher can't be 0") + + @torch.no_grad() + def forward(self, outputs, targets): + """ + Args: + outputs (`dict`): + A dictionary that contains at least these entries: + * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates. + targets (`List[dict]`): + A list of targets (len(targets) = batch_size), where each target is a dict containing: + * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of + ground-truth + objects in the target) containing the class labels + * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates. + + Returns: + `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + batch_size, num_queries = outputs["logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + target_ids = torch.cat([v["class_labels"] for v in targets]) + target_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + class_cost = -out_prob[:, target_ids] + + # Compute the L1 cost between boxes + bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) + + # Compute the giou cost between boxes + giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) + + # Final cost matrix + cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost + cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +# Copied from transformers.models.detr.modeling_detr._upcast +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +# Copied from transformers.models.detr.modeling_detr.box_area +def box_area(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# Copied from transformers.models.detr.modeling_detr.box_iou +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +# Copied from transformers.models.detr.modeling_detr.generalized_box_iou +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. + + Returns: + `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): + raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") + if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): + raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") + iou, union = box_iou(boxes1, boxes2) + + top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] + area = width_height[:, :, 0] * width_height[:, :, 1] + + return iou - (area - union) / area + + +# Copied from transformers.models.detr.modeling_detr._max_by_axis +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +# Copied from transformers.models.detr.modeling_detr.NestedTensor +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + if tensor_list[0].ndim == 3: + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + batch_shape = [len(tensor_list)] + max_size + batch_size, num_channels, height, width = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("Only 3-dimensional tensors are supported") + return NestedTensor(tensor, mask) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 9d974f1a55..98f5f72f7e 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -273,7 +273,7 @@ class DetrFrozenBatchNorm2d(nn.Module): """ def __init__(self, n): - super(DetrFrozenBatchNorm2d, self).__init__() + super().__init__() self.register_buffer("weight", torch.ones(n)) self.register_buffer("bias", torch.zeros(n)) self.register_buffer("running_mean", torch.zeros(n)) @@ -286,7 +286,7 @@ class DetrFrozenBatchNorm2d(nn.Module): if num_batches_tracked_key in state_dict: del state_dict[num_batches_tracked_key] - super(DetrFrozenBatchNorm2d, self)._load_from_state_dict( + super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) @@ -387,14 +387,14 @@ class DetrConvModel(nn.Module): return out, pos -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len + batch_size, source_len = mask.size() + target_len = target_len if target_len is not None else source_len - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) inverted_mask = 1.0 - expanded_mask @@ -449,12 +449,12 @@ class DetrLearnedPositionEmbedding(nn.Module): self.column_embeddings = nn.Embedding(50, embedding_dim) def forward(self, pixel_values, pixel_mask=None): - h, w = pixel_values.shape[-2:] - i = torch.arange(w, device=pixel_values.device) - j = torch.arange(h, device=pixel_values.device) - x_emb = self.column_embeddings(i) - y_emb = self.row_embeddings(j) - pos = torch.cat([x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1)], dim=-1) + height, width = pixel_values.shape[-2:] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) pos = pos.permute(2, 0, 1) pos = pos.unsqueeze(0) pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) @@ -506,8 +506,8 @@ class DetrAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): return tensor if position_embeddings is None else tensor + position_embeddings @@ -526,7 +526,7 @@ class DetrAttention(nn.Module): # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = hidden_states.size() + batch_size, target_len, embed_dim = hidden_states.size() # add position embeddings to the hidden states before projecting to queries and keys if position_embeddings is not None: @@ -543,35 +543,36 @@ class DetrAttention(nn.Module): # get key, value proj if is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states_original), -1, bsz) + key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz) + key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) - src_len = key_states.size(1) + source_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): + if attention_mask.size() != (batch_size, 1, target_len, source_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -580,8 +581,8 @@ class DetrAttention(nn.Module): # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to reshaped # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) else: attn_weights_reshaped = None @@ -589,15 +590,15 @@ class DetrAttention(nn.Module): attn_output = torch.bmm(attn_probs, value_states) - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" f" {attn_output.size()}" ) - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) attn_output = self.out_proj(attn_output) @@ -632,7 +633,8 @@ class DetrEncoderLayer(nn.Module): Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. position_embeddings (`torch.FloatTensor`, *optional*): position embeddings, to be added to hidden_states. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -714,7 +716,8 @@ class DetrDecoderLayer(nn.Module): Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. position_embeddings (`torch.FloatTensor`, *optional*): position embeddings that are added to the queries and keys in the cross-attention layer. @@ -724,7 +727,8 @@ class DetrDecoderLayer(nn.Module): encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -957,7 +961,7 @@ class DetrEncoder(DetrPreTrainedModel): # expand attention_mask if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None @@ -1081,15 +1085,17 @@ class DetrDecoder(DetrPreTrainedModel): combined_attention_mask = None if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + attention_mask, inputs_embeds.dtype, target_len=input_shape[-1] ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + encoder_attention_mask = _expand_mask( + encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1] + ) # optional intermediate hidden states intermediate = () if self.config.auxiliary_loss else None @@ -1889,21 +1895,22 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: - inputs: A float tensor of arbitrary shape. - The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs (0 for the negative class and 1 for the positive - class). - alpha: (optional) Weighting factor in range (0,1) to balance - positive vs negative examples. Default = -1 (no weighting). - gamma: Exponent of the modulating factor (1 - p_t) to - balance easy vs hard examples. + inputs (`torch.FloatTensor` of arbitrary shape): + The predictions for each example. + targets (`torch.FloatTensor` with the same shape as `inputs`) + A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class + and 1 for the positive class). + alpha (`float`, *optional*, defaults to `0.25`): + Optional weighting factor in the range (0,1) to balance positive vs. negative examples. + gamma (`int`, *optional*, defaults to `2`): + Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Returns: Loss tensor """ prob = inputs.sigmoid() ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + # add modulating factor p_t = prob * targets + (1 - prob) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) @@ -1981,10 +1988,10 @@ class DetrLoss(nn.Module): """ logits = outputs["logits"] device = logits.device - tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) + target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) # Count the number of predictions that are NOT "no-object" (which is the last class) card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1) - card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float()) + card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) losses = {"cardinality_error": card_err} return losses @@ -2191,19 +2198,19 @@ class DetrHungarianMatcher(nn.Module): out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] # Also concat the target labels and boxes - tgt_ids = torch.cat([v["class_labels"] for v in targets]) - tgt_bbox = torch.cat([v["boxes"] for v in targets]) + target_ids = torch.cat([v["class_labels"] for v in targets]) + target_bbox = torch.cat([v["boxes"] for v in targets]) # Compute the classification cost. Contrary to the loss, we don't use the NLL, # but approximate it in 1 - proba[target class]. # The 1 is a constant that doesn't change the matching, it can be ommitted. - class_cost = -out_prob[:, tgt_ids] + class_cost = -out_prob[:, target_ids] # Compute the L1 cost between boxes - bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1) + bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) # Compute the giou cost between boxes - giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(tgt_bbox)) + giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) # Final cost matrix cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost @@ -2267,15 +2274,17 @@ def generalized_box_iou(boxes1, boxes2): """ # degenerate boxes gives inf / nan results # so do an early check - assert (boxes1[:, 2:] >= boxes1[:, :2]).all() - assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): + raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") + if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): + raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") iou, union = box_iou(boxes1, boxes2) - lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - wh = (rb - lt).clamp(min=0) # [N,M,2] - area = wh[:, :, 0] * wh[:, :, 1] + width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] + area = width_height[:, :, 0] * width_height[:, :, 1] return iou - (area - union) / area @@ -2317,11 +2326,11 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): if tensor_list[0].ndim == 3: max_size = _max_by_axis([list(img.shape) for img in tensor_list]) batch_shape = [len(tensor_list)] + max_size - b, c, h, w = batch_shape + batch_size, num_channels, height, width = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: img.shape[1], : img.shape[2]] = False diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 1266dbfdad..110b11c653 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -1211,8 +1211,8 @@ class DetrAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): return tensor if position_embeddings is None else tensor + position_embeddings @@ -1231,7 +1231,7 @@ class DetrAttention(nn.Module): # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = hidden_states.size() + batch_size, target_len, embed_dim = hidden_states.size() # add position embeddings to the hidden states before projecting to queries and keys if position_embeddings is not None: @@ -1248,35 +1248,36 @@ class DetrAttention(nn.Module): # get key, value proj if is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states_original), -1, bsz) + key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz) + key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) - src_len = key_states.size(1) + source_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): + if attention_mask.size() != (batch_size, 1, target_len, source_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -1285,8 +1286,8 @@ class DetrAttention(nn.Module): # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to reshaped # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) else: attn_weights_reshaped = None @@ -1294,15 +1295,15 @@ class DetrAttention(nn.Module): attn_output = torch.bmm(attn_probs, value_states) - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" f" {attn_output.size()}" ) - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) attn_output = self.out_proj(attn_output) @@ -1351,7 +1352,8 @@ class DetrDecoderLayer(nn.Module): Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. position_embeddings (`torch.FloatTensor`, *optional*): position embeddings that are added to the queries and keys in the cross-attention layer. @@ -1361,7 +1363,8 @@ class DetrDecoderLayer(nn.Module): encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1416,14 +1419,14 @@ class DetrDecoderLayer(nn.Module): # Copied from transformers.models.detr.modeling_detr._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len + batch_size, source_len = mask.size() + target_len = target_len if target_len is not None else source_len - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) inverted_mask = 1.0 - expanded_mask diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index db70670b6c..d41e26c4e2 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -874,21 +874,22 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: - inputs: A float tensor of arbitrary shape. - The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs (0 for the negative class and 1 for the positive - class). - alpha: (optional) Weighting factor in range (0,1) to balance - positive vs negative examples. Default = -1 (no weighting). - gamma: Exponent of the modulating factor (1 - p_t) to - balance easy vs hard examples. + inputs (`torch.FloatTensor` of arbitrary shape): + The predictions for each example. + targets (`torch.FloatTensor` with the same shape as `inputs`) + A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class + and 1 for the positive class). + alpha (`float`, *optional*, defaults to `0.25`): + Optional weighting factor in the range (0,1) to balance positive vs. negative examples. + gamma (`int`, *optional*, defaults to `2`): + Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Returns: Loss tensor """ prob = inputs.sigmoid() ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + # add modulating factor p_t = prob * targets + (1 - prob) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) @@ -966,10 +967,10 @@ class YolosLoss(nn.Module): """ logits = outputs["logits"] device = logits.device - tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) + target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) # Count the number of predictions that are NOT "no-object" (which is the last class) card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1) - card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float()) + card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) losses = {"cardinality_error": card_err} return losses @@ -1176,19 +1177,19 @@ class YolosHungarianMatcher(nn.Module): out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] # Also concat the target labels and boxes - tgt_ids = torch.cat([v["class_labels"] for v in targets]) - tgt_bbox = torch.cat([v["boxes"] for v in targets]) + target_ids = torch.cat([v["class_labels"] for v in targets]) + target_bbox = torch.cat([v["boxes"] for v in targets]) # Compute the classification cost. Contrary to the loss, we don't use the NLL, # but approximate it in 1 - proba[target class]. # The 1 is a constant that doesn't change the matching, it can be ommitted. - class_cost = -out_prob[:, tgt_ids] + class_cost = -out_prob[:, target_ids] # Compute the L1 cost between boxes - bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1) + bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) # Compute the giou cost between boxes - giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(tgt_bbox)) + giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) # Final cost matrix cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost @@ -1252,15 +1253,17 @@ def generalized_box_iou(boxes1, boxes2): """ # degenerate boxes gives inf / nan results # so do an early check - assert (boxes1[:, 2:] >= boxes1[:, :2]).all() - assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): + raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") + if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): + raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") iou, union = box_iou(boxes1, boxes2) - lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - wh = (rb - lt).clamp(min=0) # [N,M,2] - area = wh[:, :, 0] * wh[:, :, 1] + width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] + area = width_height[:, :, 0] * width_height[:, :, 1] return iou - (area - union) / area @@ -1302,11 +1305,11 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): if tensor_list[0].ndim == 3: max_size = _max_by_axis([list(img.shape) for img in tensor_list]) batch_shape = [len(tensor_list)] + max_size - b, c, h, w = batch_shape + batch_size, num_channels, height, width = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: img.shape[1], : img.shape[2]] = False diff --git a/src/transformers/utils/dummy_timm_and_vision_objects.py b/src/transformers/utils/dummy_timm_and_vision_objects.py index 9a63196637..e990c33d2d 100644 --- a/src/transformers/utils/dummy_timm_and_vision_objects.py +++ b/src/transformers/utils/dummy_timm_and_vision_objects.py @@ -3,6 +3,30 @@ from ..utils import DummyObject, requires_backends +DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class DeformableDetrForObjectDetection(metaclass=DummyObject): + _backends = ["timm", "vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["timm", "vision"]) + + +class DeformableDetrModel(metaclass=DummyObject): + _backends = ["timm", "vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["timm", "vision"]) + + +class DeformableDetrPreTrainedModel(metaclass=DummyObject): + _backends = ["timm", "vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["timm", "vision"]) + + DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_timm_objects.py b/src/transformers/utils/dummy_timm_objects.py deleted file mode 100644 index c964d40315..0000000000 --- a/src/transformers/utils/dummy_timm_objects.py +++ /dev/null @@ -1,32 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import requires_backends - - -DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None - - -class DetrForObjectDetection: - def __init__(self, *args, **kwargs): - requires_backends(self, ["timm"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["timm"]) - - -class DetrForSegmentation: - def __init__(self, *args, **kwargs): - requires_backends(self, ["timm"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["timm"]) - - -class DetrModel: - def __init__(self, *args, **kwargs): - requires_backends(self, ["timm"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["timm"]) diff --git a/tests/models/deformable_detr/__init__.py b/tests/models/deformable_detr/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py new file mode 100644 index 0000000000..2c5bd4eda0 --- /dev/null +++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py @@ -0,0 +1,625 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch Deformable DETR model. """ + + +import inspect +import math +import unittest +from typing import Dict, List, Tuple + +from transformers import DeformableDetrConfig, is_timm_available, is_vision_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_timm, require_torch_gpu, require_vision, slow, torch_device + +from ...generation.test_generation_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor + + +if is_timm_available(): + import torch + + from transformers import DeformableDetrForObjectDetection, DeformableDetrModel + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class DeformableDetrModelTester: + def __init__( + self, + parent, + batch_size=8, + is_training=True, + use_labels=True, + hidden_size=256, + num_hidden_layers=2, + num_attention_heads=8, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + num_queries=12, + num_channels=3, + image_size=196, + n_targets=8, + num_labels=91, + num_feature_levels=4, + encoder_n_points=2, + decoder_n_points=6, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.num_queries = num_queries + self.num_channels = num_channels + self.image_size = image_size + self.n_targets = n_targets + self.num_labels = num_labels + self.num_feature_levels = num_feature_levels + self.encoder_n_points = encoder_n_points + self.decoder_n_points = decoder_n_points + + # we also set the expected seq length for both encoder and decoder + self.encoder_seq_length = ( + math.ceil(self.image_size / 8) ** 2 + + math.ceil(self.image_size / 16) ** 2 + + math.ceil(self.image_size / 32) ** 2 + + math.ceil(self.image_size / 64) ** 2 + ) + self.decoder_seq_length = self.num_queries + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if self.use_labels: + # labels is a list of Dict (each Dict being the labels for a given example in the batch) + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device) + labels.append(target) + + config = self.get_config() + return config, pixel_values, pixel_mask, labels + + def get_config(self): + return DeformableDetrConfig( + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + num_queries=self.num_queries, + num_labels=self.num_labels, + num_feature_levels=self.num_feature_levels, + encoder_n_points=self.encoder_n_points, + decoder_n_points=self.decoder_n_points, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask} + return config, inputs_dict + + def create_and_check_deformable_detr_model(self, config, pixel_values, pixel_mask, labels): + model = DeformableDetrModel(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size)) + + def create_and_check_deformable_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = DeformableDetrForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_timm +class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (DeformableDetrModel, DeformableDetrForObjectDetection) if is_timm_available() else () + is_encoder_decoder = True + test_torchscript = False + test_pruning = False + test_head_masking = False + test_missing_keys = False + + # special case for head models + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "DeformableDetrForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + target["masks"] = torch.ones( + self.model_tester.n_targets, + self.model_tester.image_size, + self.model_tester.image_size, + device=torch_device, + dtype=torch.float, + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = DeformableDetrModelTester(self) + self.config_tester = ConfigTester(self, config_class=DeformableDetrConfig, has_text_modality=False) + + def test_config(self): + # we don't test common_properties and arguments_init as these don't apply for Deformable DETR + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + + def test_deformable_detr_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deformable_detr_model(*config_and_inputs) + + def test_deformable_detr_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deformable_detr_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="Deformable DETR does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Deformable DETR does not have a get_input_embeddings method") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Deformable DETR is not a generative model") + def test_generate_without_input_ids(self): + pass + + @unittest.skip(reason="Deformable DETR does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip(reason="Feed forward chunking is not implemented") + def test_feed_forward_chunking(self): + pass + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + self.model_tester.num_feature_levels, + self.model_tester.encoder_n_points, + ], + ) + out_len = len(outputs) + + correct_outlen = 8 + + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + # Object Detection model returns pred_logits and pred_boxes + if model_class.__name__ == "DeformableDetrForObjectDetection": + correct_outlen += 2 + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, self.model_tester.num_queries, self.model_tester.num_queries], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + self.model_tester.num_feature_levels, + self.model_tester.decoder_n_points, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + self.model_tester.num_feature_levels, + self.model_tester.encoder_n_points, + ], + ) + + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip( + tuple_object.values(), dict_object.values() + ): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." + ), + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + print("Model class:", model_class) + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence( + model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True} + ) + + def test_retain_grad_hidden_states_attentions(self): + # removed retain_grad and grad on decoder_hidden_states, as queries don't require grad + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + # we take the second output since last_hidden_state is the second item + output = outputs[1] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_attentions = outputs.encoder_attentions[0] + encoder_hidden_states.retain_grad() + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + if model.config.is_encoder_decoder: + expected_arg_names = ["pixel_values", "pixel_mask"] + expected_arg_names.extend( + ["head_mask", "decoder_head_mask", "encoder_outputs"] + if "head_mask" and "decoder_head_mask" in arg_names + else [] + ) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + else: + expected_arg_names = ["pixel_values", "pixel_mask"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_different_timm_backbone(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # let's pick a random timm backbone + config.backbone = "tf_mobilenetv3_small_075" + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "DeformableDetrForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels, + ) + self.assertEqual(outputs.logits.shape, expected_shape) + + self.assertTrue(outputs) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + print("Model class:", model_class) + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + if param.requires_grad: + if ( + "level_embed" in name + or "sampling_offsets.bias" in name + or "value_proj" in name + or "output_proj" in name + or "reference_points" in name + ): + continue + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + +TOLERANCE = 1e-4 + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_timm +@require_vision +@slow +class DeformableDetrModelIntegrationTests(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return AutoFeatureExtractor.from_pretrained("SenseTime/deformable-detr") if is_vision_available() else None + + def test_inference_object_detection_head(self): + model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr").to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device) + pixel_values = encoding["pixel_values"].to(torch_device) + pixel_mask = encoding["pixel_mask"].to(torch_device) + + with torch.no_grad(): + outputs = model(pixel_values, pixel_mask) + + expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels)) + self.assertEqual(outputs.logits.shape, expected_shape_logits) + + expected_logits = torch.tensor( + [[-9.6645, -4.3449, -5.8705], [-9.7035, -3.8504, -5.0724], [-10.5634, -5.3379, -7.5116]] + ).to(torch_device) + expected_boxes = torch.tensor( + [[0.8693, 0.2289, 0.2492], [0.3150, 0.5489, 0.5845], [0.5563, 0.7580, 0.8518]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_logits, atol=1e-4)) + + expected_shape_boxes = torch.Size((1, model.config.num_queries, 4)) + self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) + self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4)) + + def test_inference_object_detection_head_with_box_refine_two_stage(self): + model = DeformableDetrForObjectDetection.from_pretrained( + "SenseTime/deformable-detr-with-box-refine-two-stage" + ).to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device) + pixel_values = encoding["pixel_values"].to(torch_device) + pixel_mask = encoding["pixel_mask"].to(torch_device) + + with torch.no_grad(): + outputs = model(pixel_values, pixel_mask) + + expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels)) + self.assertEqual(outputs.logits.shape, expected_shape_logits) + + expected_logits = torch.tensor( + [[-6.7108, -4.3213, -6.3777], [-8.9014, -6.1799, -6.7240], [-6.9315, -4.4735, -6.2298]] + ).to(torch_device) + expected_boxes = torch.tensor( + [[0.2583, 0.5499, 0.4683], [0.7652, 0.9068, 0.4882], [0.5490, 0.2763, 0.0564]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_logits, atol=1e-4)) + + expected_shape_boxes = torch.Size((1, model.config.num_queries, 4)) + self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) + self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4)) + + @require_torch_gpu + def test_inference_object_detection_head_equivalence_cpu_gpu(self): + feature_extractor = self.default_feature_extractor + image = prepare_img() + encoding = feature_extractor(images=image, return_tensors="pt") + pixel_values = encoding["pixel_values"] + pixel_mask = encoding["pixel_mask"] + + # 1. run model on CPU + model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr-single-scale") + + with torch.no_grad(): + cpu_outputs = model(pixel_values, pixel_mask) + + # 2. run model on GPU + model.to("cuda") + + with torch.no_grad(): + gpu_outputs = model(pixel_values.to("cuda"), pixel_mask.to("cuda")) + + # 3. assert equivalence + for key in cpu_outputs.keys(): + assert torch.allclose(cpu_outputs[key], gpu_outputs[key].cpu(), atol=1e-4) + + expected_logits = torch.tensor( + [[-9.9051, -4.2541, -6.4852], [-9.6947, -4.0854, -6.8033], [-10.0665, -5.8470, -7.7003]] + ) + assert torch.allclose(cpu_outputs.logits[0, :3, :3], expected_logits, atol=1e-4) diff --git a/tests/pipelines/test_pipelines_object_detection.py b/tests/pipelines/test_pipelines_object_detection.py index 538f313151..b1d43f8a87 100644 --- a/tests/pipelines/test_pipelines_object_detection.py +++ b/tests/pipelines/test_pipelines_object_detection.py @@ -53,6 +53,12 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING def get_test_pipeline(self, model, tokenizer, feature_extractor): + if model.__class__.__name__ == "DeformableDetrForObjectDetection": + self.skipTest( + """Deformable DETR requires a custom CUDA kernel. + """ + ) + object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor) return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"] diff --git a/utils/check_repo.py b/utils/check_repo.py index 207592c91c..ea3b997f1f 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -46,6 +46,8 @@ PRIVATE_MODELS = [ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested + "DeformableDetrEncoder", # Building part of bigger (tested) model. + "DeformableDetrDecoder", # Building part of bigger (tested) model. "OPTDecoder", # Building part of bigger (tested) model. "DecisionTransformerGPT2Model", # Building part of bigger (tested) model. "SegformerDecodeHead", # Building part of bigger (tested) model. diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index abe539cf46..0a27bdb527 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -28,6 +28,7 @@ src/transformers/models/data2vec/modeling_data2vec_audio.py src/transformers/models/data2vec/modeling_data2vec_vision.py src/transformers/models/deberta/modeling_deberta.py src/transformers/models/deberta_v2/modeling_deberta_v2.py +src/transformers/models/deformable_detr/modeling_deformable_detr.py src/transformers/models/deit/modeling_deit.py src/transformers/models/deit/modeling_tf_deit.py src/transformers/models/detr/modeling_detr.py