0%

pytorch扩展C

c++测试文件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
//#include <torch/script.h>
#include <torch/extension.h>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc.hpp>

using namespace cv;
using namespace std;

torch::Tensor extend_gray(torch::Tensor image, torch::Tensor warp) {
// BEGIN image_mat
cv::Mat image_mat(/*rows=*/image.size(0),
/*cols=*/image.size(1),
/*type=*/CV_32FC1,
/*data=*/image.data_ptr<float>());
// END image_mat

// BEGIN warp_mat
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data_ptr<float>());
// END warp_mat

// BEGIN output_mat
cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});
// END output_mat

// BEGIN output_tensor
torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
return output.clone();
// END output_tensor
}

// pybind11 绑定
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("extend_gray", &extend_gray, "extend gray");
}

同目录下的 setup.py文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from setuptools import setup
import os
import glob
from torch.utils.cpp_extension import BuildExtension, CppExtension

# 头文件目录
include_dirs = os.path.dirname(os.path.abspath(__file__)) ## 头文件路径
# 源代码目录
source_cpu = glob.glob(os.path.join(include_dirs, '*.cpp')) ## cpp文件列表 ['', '']

setup(
name='extend_gray', # 模块名称,需要在python中调用
version="0.1",
ext_modules=[
CppExtension('extend_gray',
sources=source_cpu,
library_dirs=['/usr/local/lib'], # '/usr/local/lib'为opencv的动态库路径 -L/usr/...
include_dirs=[include_dirs, '/usr/local/include/opencv', '/usr/local/include'], ## 头文件路径后两个为 opencv的
libraries=["opencv_core", "opencv_imgproc", "opencv_imgcodecs", "opencv_highgui"]),
## 相当于g++ 编译的 -lopencv_core 一定要写全, 否则不写全编译不会报错但是运行会出错,必要时可以直接g++编译测试看用到哪些动态库
],
cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
}
)

测试文件

1
2
3
4
import extend_gray
import torch

print(extend_gray.extend_gray(torch.randn(32, 32), torch.rand(3, 3)))

​ 遇到提示 .so 库找不到 使用 find 命令在环境中搜,然后 添加到 LD_LIBRARY_PATH 变量中 写入 .bashrc 文件后面