CPython基础使用

编译器: g++ (x86_64-win32-seh-rev0, Built by MinGW-W64 project) 8.1.0
Python环境: Python3.9

#define PY_SSIZE_T_CLEAN
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include <numpy/arrayobject.h>

const int N = 256;
static npy_uint8 initlut[N * N * N][3];
// 注意类型 一开始写的int 结果一直找不到bug 
// 最后发现是数组的内存分布的问题 才发现这里类型错了

static PyObject* init(PyObject* self, PyObject* args) {
    if (!PyArg_ParseTuple(args, "")) {
        return NULL;
    }
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < N; ++j) {
            for (int k = 0; k < N; ++k) {
                initlut[i*N*N + j*N + k][0] = i;
                initlut[i*N*N + j*N + k][1] = j;
                initlut[i*N*N + j*N + k][2] = k;
            }
        }
    }
    npy_intp dims[2] = {N * N * N, 3};
    PyObject* np_array = PyArray_SimpleNewFromData(
        2,
        dims,
        NPY_UINT8, 
    // 只是把initlut按照NPY_UINT8来解析 
    // 从下一个参数可以看出实际上传的就是个void* 不关心原本指向的类型
    // 因此原本的内存分布一定要和想要的类型对应起来
        (void*)initlut
    );
    // PyArray_SimpleNewFromData返回PyObject*指针,
    // 如果想指向PyArrayObject,必须先强转再赋值
    return np_array;
}

static PyObject* getimg(PyObject* self, PyObject* args) {
    PyArrayObject *img, *lut;
    
    if (!PyArg_ParseTuple(args, "O!O!", &PyArray_Type, &img, &PyArray_Type, &lut)) {
        return NULL;
    }

    if (PyArray_NDIM(img) != 3 || PyArray_DIM(img, 2) != 3) {
        PyErr_SetString(PyExc_ValueError, "img must be HxWx3 array");
        return NULL;
    }

    /* chech uint8 */
    if (PyArray_TYPE(img) != NPY_UINT8 || PyArray_TYPE(lut) != NPY_UINT8) {
        PyErr_SetString(PyExc_TypeError, "Arrays must be uint8");
        return NULL;
    }

    npy_intp H = PyArray_DIM(img, 0);
    npy_intp W = PyArray_DIM(img, 1);

    npy_intp out_dims[3] = {H, W, 3};
    PyArrayObject* newimg = (PyArrayObject*)PyArray_SimpleNew(3, out_dims, NPY_UINT8);
    if (!newimg) {
        return NULL;
    }

    npy_uint8* img_data = (npy_uint8*)PyArray_DATA(img);
    npy_uint8* lut_data = (npy_uint8*)PyArray_DATA(lut);
    npy_uint8* newimg_data = (npy_uint8*)PyArray_DATA(newimg);
    // 然后就可以按照普通的指针来操作了
    // npy_uint8 (*lut_data)[3] = (npy_uint8(*)[3])PyArray_DATA(lut); // 然后按二维数组读取 直接这样做有问题

    for (npy_intp h = 0; h < H; h++) {
        for (npy_intp w = 0; w < W; w++) {
            npy_uint8 r = img_data[h * W * 3 + w * 3];
            npy_uint8 g = img_data[h * W * 3 + w * 3 + 1];
            npy_uint8 b = img_data[h * W * 3 + w * 3 + 2];
            npy_intp idx = r * (N * N) + g * N + b;
            for (npy_intp c = 0; c < 3; ++c) {
                newimg_data[h * W * 3 + w * 3 + c] = lut_data[idx * 3 + c];
            }
        }
    }

    return (PyObject*)newimg;
}

static PyMethodDef methods[] = {
    {"init", init, METH_VARARGS, "init lut"},
    {"getimg", getimg, METH_VARARGS, "return lut(img)"},
    {NULL, NULL, 0, NULL}
};

static struct PyModuleDef module = {
    PyModuleDef_HEAD_INIT,
    "luts",
    NULL,
    -1,
    methods
};

PyMODINIT_FUNC PyInit_luts(void) {
    import_array();
    return PyModule_Create(&module);
}
posted @ 2025-05-23 13:52  TimeLimit  阅读(18)  评论(0)    收藏  举报