Clean up CUDA kernels (#23455)
This commit is contained in:
@@ -13,16 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Loading of Deformable DETR's CUDA kernels"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_cuda_kernels():
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_kernel")
|
||||
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
|
||||
src_files = [
|
||||
os.path.join(root, filename)
|
||||
root / filename
|
||||
for filename in [
|
||||
"vision.cpp",
|
||||
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
|
||||
@@ -33,10 +33,8 @@ def load_cuda_kernels():
|
||||
load(
|
||||
"MultiScaleDeformableAttention",
|
||||
src_files,
|
||||
# verbose=True,
|
||||
with_cuda=True,
|
||||
extra_include_paths=[root],
|
||||
# build_directory=os.path.dirname(os.path.realpath(__file__)),
|
||||
extra_include_paths=[str(root)],
|
||||
extra_cflags=["-DWITH_CUDA=1"],
|
||||
extra_cuda_cflags=[
|
||||
"-DCUDA_HAS_FP16=1",
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -56,8 +56,8 @@ def load_cuda_kernels():
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
def append_root(files):
|
||||
src_folder = os.path.dirname(os.path.realpath(__file__))
|
||||
return [os.path.join(src_folder, file) for file in files]
|
||||
src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
|
||||
return [src_folder / file for file in files]
|
||||
|
||||
src_files = append_root(
|
||||
["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"]
|
||||
|
||||
@@ -21,8 +21,8 @@ from pathlib import Path
|
||||
FILES_TO_FIND = [
|
||||
"kernels/rwkv/wkv_cuda.cu",
|
||||
"kernels/rwkv/wkv_op.cpp",
|
||||
"models/deformable_detr/custom_kernel/ms_deform_attn.h",
|
||||
"models/deformable_detr/custom_kernel/cuda/ms_deform_im2col_cuda.cuh",
|
||||
"kernels/deformable_detr/ms_deform_attn.h",
|
||||
"kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh",
|
||||
"models/graphormer/algos_graphormer.pyx",
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user