完整代码@折腾笔记[50]-cuda的性能优化及显存访问安全措施

摘要

本文档包含《折腾笔记[50]-cuda的性能优化及显存访问安全措施》中 CudaSharp 项目优化后的完整源代码。


目录

  1. SafeMem 内存/显存安全库
  2. SIFT GPU 加速模块
  3. 图像操作模块
  4. CUDA Native 导出层
  5. C# 封装库
  6. 测试程序
  7. 构建脚本

1. SafeMem 内存/显存安全库

1.1 safemem.h

#ifndef SAFEMEM_H
#define SAFEMEM_H

/* Headers are included BEFORE the extern "C" block is opened: when compiled
 * as C++, cuda_runtime.h contains C++-only declarations (templates and
 * overloads) that cannot be given C linkage, so including it inside
 * extern "C" can fail to compile in C++ translation units. */
#include <stddef.h>

#ifdef __cplusplus
#include <cuda_runtime.h>
#endif

/* For C compilation, define minimal CUDA types so the prototypes below can be
 * declared without the CUDA runtime header. Skipped if the real header has
 * already been included. */
#ifndef __cplusplus
#ifndef __CUDA_RUNTIME_H__
typedef enum cudaMemcpyKind {
    cudaMemcpyHostToHost = 0,
    cudaMemcpyHostToDevice = 1,
    cudaMemcpyDeviceToHost = 2,
    cudaMemcpyDeviceToDevice = 3,
    cudaMemcpyDefault = 4
} cudaMemcpyKind;
typedef int cudaError_t;
#define cudaSuccess 0
#endif
#endif

/* Everything declared from here to the matching close at the end of the
 * header has C linkage. */
#ifdef __cplusplus
extern "C" {
#endif

/* ============================================================
 * SafeMem - Host Memory Management
 * Wraps malloc/calloc/realloc/free with safety checks
 * ============================================================ */

/* Each user block is surrounded by SAFE_CANARY_SIZE guard bytes on both
 * sides, filled with SAFE_CANARY_VALUE (as 64-bit words) by the host
 * implementation. safe_malloc_impl fills fresh blocks with SAFE_FILL_NEW;
 * safe_free_impl overwrites a block (guards included) with SAFE_FILL_FREED
 * just before releasing it. */
#define SAFE_CANARY_VALUE 0xDEADBEEFCAFEBABEULL
#define SAFE_CANARY_SIZE  16
#define SAFE_FILL_FREED   0xFE
#define SAFE_FILL_NEW     0xCD

/* Allocation tracking entry: one per host allocation. Entries stay in the
 * list after free (is_freed set) so double frees can be reported. */
typedef struct SafeAllocEntry {
    void* ptr;                    /* user pointer (just past the front canary) */
    size_t size;                  /* user-visible size in bytes */
    const char* file;             /* allocation site (__FILE__) */
    int line;                     /* allocation site (__LINE__) */
    int is_freed;                 /* non-zero once released */
    struct SafeAllocEntry* next;  /* next entry in the tracking list */
} SafeAllocEntry;

/* Initialize safemem subsystem (idempotent; the allocation functions call it
 * lazily when needed). */
void safe_mem_init(void);

/* Shutdown: print every still-live allocation to stderr as a leak and
 * release the tracking list. The leaked blocks themselves are not freed. */
void safe_mem_shutdown(void);

/* Core allocation functions.
 * On allocation failure, an untracked pointer, a double free, or canary
 * corruption they print a message to stderr and call exit(1) instead of
 * returning. malloc/calloc with a zero size print a warning and return NULL;
 * realloc to size 0 frees the block and returns NULL. */
void* safe_malloc_impl(size_t size, const char* file, int line);
void* safe_calloc_impl(size_t nmemb, size_t size, const char* file, int line);
void* safe_realloc_impl(void* ptr, size_t size, const char* file, int line);
void  safe_free_impl(void** ptr, const char* file, int line);

/* Check canary integrity: returns 1 for a live tracked block whose canaries
 * are intact, 0 otherwise (NULL, untracked, already freed, or corrupted). */
int safe_mem_check(const void* ptr, const char* file, int line);

/* Print allocation statistics to stdout */
void safe_mem_stats(void);

/* Macros for automatic file/line tracking.
 * SAFE_FREE takes the pointer variable itself (it passes its address) and
 * resets it to NULL after freeing. */
#define SAFE_MALLOC(size)       safe_malloc_impl(size, __FILE__, __LINE__)
#define SAFE_CALLOC(n, size)    safe_calloc_impl(n, size, __FILE__, __LINE__)
#define SAFE_REALLOC(ptr, size) safe_realloc_impl(ptr, size, __FILE__, __LINE__)
#define SAFE_FREE(ptr)          safe_free_impl((void**)&(ptr), __FILE__, __LINE__)
#define SAFE_CHECK(ptr)         safe_mem_check(ptr, __FILE__, __LINE__)

/* ============================================================
 * SafeMem - Device Memory Management
 * Wraps cudaMalloc/cudaFree with safety checks
 * ============================================================ */

/* Device allocation tracking. The same list also records pinned host
 * buffers obtained through safe_cudaMallocHost_impl. Entries are kept after
 * free (is_freed set) so double frees can be reported. */
typedef struct SafeDeviceAllocEntry {
    void* ptr;                          /* pointer returned by the CUDA allocator */
    size_t size;                        /* requested size in bytes */
    const char* file;                   /* allocation site (__FILE__) */
    int line;                           /* allocation site (__LINE__) */
    int is_freed;                       /* non-zero once released */
    struct SafeDeviceAllocEntry* next;  /* next entry in the tracking list */
} SafeDeviceAllocEntry;

/* Idempotent init; shutdown reports still-live entries as leaks to stderr and
 * attempts to release them. */
void safe_device_mem_init(void);
void safe_device_mem_shutdown(void);

/* Unlike the host API these wrappers never exit: failures are printed to
 * stderr and the CUDA error code is returned to the caller. */
cudaError_t safe_cudaMalloc_impl(void** devPtr, size_t size, const char* file, int line);
cudaError_t safe_cudaMallocHost_impl(void** ptr, size_t size, const char* file, int line);
cudaError_t safe_cudaFree_impl(void* devPtr, const char* file, int line);
cudaError_t safe_cudaFreeHost_impl(void* ptr, const char* file, int line);

/* NULL dst/src (or devPtr) is rejected with cudaErrorInvalidValue before the
 * CUDA call is made. */
cudaError_t safe_cudaMemcpy_impl(void* dst, const void* src, size_t count,
                                  cudaMemcpyKind kind, const char* file, int line);
cudaError_t safe_cudaMemset_impl(void* devPtr, int value, size_t count,
                                  const char* file, int line);

/* Print device allocation statistics to stdout */
void safe_device_mem_stats(void);

/* Tracking macros. SAFE_CUDA_MALLOC / SAFE_CUDA_MALLOC_HOST take the ADDRESS
 * of the pointer variable. Unlike SAFE_FREE, SAFE_CUDA_FREE and
 * SAFE_CUDA_FREE_HOST take the pointer value and do not reset the caller's
 * variable to NULL. */
#define SAFE_CUDA_MALLOC(devPtr, size) \
    safe_cudaMalloc_impl((void**)(devPtr), size, __FILE__, __LINE__)
#define SAFE_CUDA_MALLOC_HOST(ptr, size) \
    safe_cudaMallocHost_impl((void**)(ptr), size, __FILE__, __LINE__)
#define SAFE_CUDA_FREE(devPtr) \
    safe_cudaFree_impl(devPtr, __FILE__, __LINE__)
#define SAFE_CUDA_FREE_HOST(ptr) \
    safe_cudaFreeHost_impl(ptr, __FILE__, __LINE__)
#define SAFE_CUDA_MEMCPY(dst, src, count, kind) \
    safe_cudaMemcpy_impl(dst, src, count, kind, __FILE__, __LINE__)
#define SAFE_CUDA_MEMSET(devPtr, value, count) \
    safe_cudaMemset_impl(devPtr, value, count, __FILE__, __LINE__)

/* ============================================================
 * SafeMem - Memory Pool for Device
 * Eliminates fragmentation by pre-allocating large chunks
 * ============================================================ */

/* One cudaMalloc'd region carved up by a bump allocator. No per-block free
 * function is declared, so individual allocations cannot be returned;
 * dev_pool_reset() releases everything at once. */
typedef struct DeviceMemPool {
    void* pool_base;          /* Base pointer of pre-allocated memory */
    size_t pool_size;         /* Total pool size */
    size_t used;              /* Currently used */
    size_t peak_used;         /* Peak usage */
    int num_allocs;           /* Allocations handed out since the last reset (never decremented) */

    /* Simple bump allocator + free list */
    void* bump_ptr;           /* Current bump pointer */
    struct PoolFreeNode* free_list;  /* reusable blocks; PoolFreeNode is defined in the pool implementation */
} DeviceMemPool;

/* Create a device memory pool; initial_size == 0 selects a 64 MB pool.
 * Returns NULL if the struct or the device region cannot be allocated. */
DeviceMemPool* dev_pool_create(size_t initial_size);

/* Allocate from pool. Sizes are rounded up to a multiple of 256 bytes;
 * returns NULL for size == 0, a NULL pool, or exhaustion.
 * NOTE: the implementation performs no locking, so a pool shared between
 * host threads must be externally synchronized. */
void* dev_pool_malloc(DeviceMemPool* pool, size_t size);

/* Reset pool (free all allocations at once); peak_used is preserved */
void dev_pool_reset(DeviceMemPool* pool);

/* Destroy pool: reset, release the device region, then free the struct */
void dev_pool_destroy(DeviceMemPool* pool);

/* Get pool statistics; any output pointer may be NULL to skip that value */
void dev_pool_stats(DeviceMemPool* pool, size_t* total, size_t* used, size_t* peak);

#ifdef __cplusplus
}
#endif

#endif /* SAFEMEM_H */

1.2 safemem_host.c

#include "safemem.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>

/* ============================================================
 * SafeMem Host Implementation
 * ============================================================ */

/* Head of the singly-linked tracking list; new entries are pushed at the
 * front, and freed entries are kept (is_freed set). */
static SafeAllocEntry* g_alloc_head = NULL;
/* Cumulative bytes handed out (realloc adds the new size). */
static size_t g_total_allocated = 0;
/* Cumulative bytes released (realloc adds the old size). */
static size_t g_total_freed = 0;
/* High-water mark of (g_total_allocated - g_total_freed). */
static size_t g_peak_allocated = 0;
/* Non-zero between safe_mem_init() and safe_mem_shutdown(). */
static int g_initialized = 0;

/* Bring the host tracker into a clean state. Calling it again while already
 * initialized does nothing. */
void safe_mem_init(void) {
    if (!g_initialized) {
        g_peak_allocated = 0;
        g_total_freed = 0;
        g_total_allocated = 0;
        g_alloc_head = NULL;
        g_initialized = 1;
    }
}

/* Report every still-live host allocation as a leak (to stderr), print a
 * summary, release all tracking nodes, and mark the tracker uninitialized.
 * The leaked user blocks themselves are not freed. */
void safe_mem_shutdown(void) {
    SafeAllocEntry* node;
    SafeAllocEntry* after;
    size_t lost_bytes = 0;
    int lost_blocks = 0;

    if (!g_initialized) {
        return;
    }

    for (node = g_alloc_head; node != NULL; node = after) {
        after = node->next;
        if (node->is_freed == 0) {
            lost_blocks += 1;
            lost_bytes += node->size;
            fprintf(stderr, "[SafeMem] LEAK: %p, size=%zu, %s:%d\n",
                    node->ptr, node->size, node->file, node->line);
        }
        free(node);
    }

    if (lost_blocks == 0) {
        fprintf(stderr, "[SafeMem] No memory leaks detected.\n");
    } else {
        fprintf(stderr, "[SafeMem] TOTAL LEAKS: %d allocations, %zu bytes\n",
                lost_blocks, lost_bytes);
    }

    g_alloc_head = NULL;
    g_initialized = 0;
}

/* Record a new live allocation at the front of the tracking list and update
 * the allocated/peak counters. Exits the process if the tracking node itself
 * cannot be allocated. */
static void add_alloc_entry(void* ptr, size_t size, const char* file, int line) {
    size_t live;
    SafeAllocEntry* node = (SafeAllocEntry*)malloc(sizeof *node);

    if (node == NULL) {
        fprintf(stderr, "[SafeMem] FATAL: Failed to allocate tracking entry\n");
        exit(1);
    }

    node->next = g_alloc_head;
    node->is_freed = 0;
    node->line = line;
    node->file = file;
    node->size = size;
    node->ptr = ptr;
    g_alloc_head = node;

    g_total_allocated += size;
    live = g_total_allocated - g_total_freed;
    g_peak_allocated = (live > g_peak_allocated) ? live : g_peak_allocated;
}

static SafeAllocEntry* find_entry(void* ptr) {
    SafeAllocEntry* entry = g_alloc_head;
    while (entry) {
        if (entry->ptr == ptr) return entry;
        entry = entry->next;
    }
    return NULL;
}

/* Fill the SAFE_CANARY_SIZE guard bytes before ptr and after ptr + size with
 * SAFE_CANARY_VALUE, one 64-bit word at a time.
 * The trailing guard starts at ptr + size, which is not 8-byte aligned when
 * size % 8 != 0; memcpy is used instead of a uint64_t* store because a
 * misaligned typed access is undefined behavior (and faults on some CPUs). */
static void write_canary(void* ptr, size_t size) {
    const uint64_t canary = SAFE_CANARY_VALUE;
    unsigned char* pre = (unsigned char*)ptr - SAFE_CANARY_SIZE;
    unsigned char* post = (unsigned char*)ptr + size;
    size_t off;

    for (off = 0; off + sizeof canary <= SAFE_CANARY_SIZE; off += sizeof canary) {
        memcpy(pre + off, &canary, sizeof canary);
        memcpy(post + off, &canary, sizeof canary);
    }
}

/* Return 1 if both guard regions around [ptr, ptr + size) still hold
 * SAFE_CANARY_VALUE in every 64-bit word, 0 otherwise.
 * Reads go through memcpy because the trailing guard at ptr + size may be
 * misaligned for uint64_t (undefined behavior with a typed load). */
static int check_canary(const void* ptr, size_t size) {
    const uint64_t canary = SAFE_CANARY_VALUE;
    const unsigned char* pre = (const unsigned char*)ptr - SAFE_CANARY_SIZE;
    const unsigned char* post = (const unsigned char*)ptr + size;
    size_t off;

    for (off = 0; off + sizeof canary <= SAFE_CANARY_SIZE; off += sizeof canary) {
        uint64_t before, after;
        memcpy(&before, pre + off, sizeof before);
        memcpy(&after, post + off, sizeof after);
        if (before != canary || after != canary) return 0;
    }
    return 1;
}

/* Allocate `size` bytes surrounded by canary guards, fill the user region with
 * SAFE_FILL_NEW, and track it.
 * Returns NULL (with a warning) for size == 0. Exits the process if the
 * request overflows size_t once padded, or if malloc fails. */
void* safe_malloc_impl(size_t size, const char* file, int line) {
    if (!g_initialized) safe_mem_init();
    if (size == 0) {
        fprintf(stderr, "[SafeMem] WARN: malloc(0) at %s:%d\n", file, line);
        return NULL;
    }

    /* Without this check the padded size wraps to a tiny value, malloc
     * succeeds, and the memset/canary writes below overflow the heap. */
    if (size > SIZE_MAX - 2 * SAFE_CANARY_SIZE) {
        fprintf(stderr, "[SafeMem] FATAL: malloc(%zu) size overflow at %s:%d\n", size, file, line);
        exit(1);
    }

    void* raw = malloc(size + 2 * SAFE_CANARY_SIZE);
    if (!raw) {
        fprintf(stderr, "[SafeMem] FATAL: malloc(%zu) failed at %s:%d\n", size, file, line);
        exit(1);
    }

    void* ptr = (char*)raw + SAFE_CANARY_SIZE;
    memset(ptr, SAFE_FILL_NEW, size);
    write_canary(ptr, size);
    add_alloc_entry(ptr, size, file, line);

    return ptr;
}

/* Allocate a zeroed array of nmemb elements of `size` bytes with canary
 * guards, and track it.
 * Returns NULL (with a warning) if either count is zero. Exits the process if
 * nmemb * size (or that plus the canary padding) overflows size_t, or if
 * calloc fails. */
void* safe_calloc_impl(size_t nmemb, size_t size, const char* file, int line) {
    if (!g_initialized) safe_mem_init();
    if (nmemb == 0 || size == 0) {
        fprintf(stderr, "[SafeMem] WARN: calloc(0) at %s:%d\n", file, line);
        return NULL;
    }

    /* calloc's whole point is an overflow-checked multiply; computing
     * nmemb * size by hand and passing it as a single count loses that, so
     * check both the product and the padded total explicitly. */
    if (nmemb > SIZE_MAX / size ||
        nmemb * size > SIZE_MAX - 2 * SAFE_CANARY_SIZE) {
        fprintf(stderr, "[SafeMem] FATAL: calloc(%zu, %zu) size overflow at %s:%d\n",
                nmemb, size, file, line);
        exit(1);
    }

    size_t total = nmemb * size;
    void* raw = calloc(1, total + 2 * SAFE_CANARY_SIZE);
    if (!raw) {
        fprintf(stderr, "[SafeMem] FATAL: calloc(%zu, %zu) failed at %s:%d\n", nmemb, size, file, line);
        exit(1);
    }

    void* ptr = (char*)raw + SAFE_CANARY_SIZE;
    write_canary(ptr, total);
    add_alloc_entry(ptr, total, file, line);

    return ptr;
}

void* safe_realloc_impl(void* ptr, size_t size, const char* file, int line) {
    if (!g_initialized) safe_mem_init();
    if (!ptr) return safe_malloc_impl(size, file, line);
    if (size == 0) {
        safe_free_impl(&ptr, file, line);
        return NULL;
    }

    SafeAllocEntry* entry = find_entry(ptr);
    if (!entry) {
        fprintf(stderr, "[SafeMem] ERROR: realloc of untracked pointer %p at %s:%d\n", ptr, file, line);
        exit(1);
    }

    if (!check_canary(ptr, entry->size)) {
        fprintf(stderr, "[SafeMem] ERROR: Canary corruption detected before realloc %p at %s:%d\n", ptr, file, line);
        exit(1);
    }

    void* raw = realloc((char*)ptr - SAFE_CANARY_SIZE, size + 2 * SAFE_CANARY_SIZE);
    if (!raw) {
        fprintf(stderr, "[SafeMem] FATAL: realloc(%p, %zu) failed at %s:%d\n", ptr, size, file, line);
        exit(1);
    }

    void* new_ptr = (char*)raw + SAFE_CANARY_SIZE;
    entry->ptr = new_ptr;
    g_total_freed += entry->size;
    g_total_allocated += size;
    entry->size = size;
    write_canary(new_ptr, size);

    size_t current = g_total_allocated - g_total_freed;
    if (current > g_peak_allocated) g_peak_allocated = current;

    return new_ptr;
}

/* Release a tracked block and reset the caller's pointer to NULL.
 * A NULL handle or NULL *ptr is a silent no-op. The whole block, guards
 * included, is overwritten with SAFE_FILL_FREED before free(). Exits the
 * process on an untracked pointer, a double free, or canary corruption. */
void safe_free_impl(void** ptr, const char* file, int line) {
    void* user;
    char* base;
    SafeAllocEntry* rec;

    if (!g_initialized) {
        safe_mem_init();
    }
    if (ptr == NULL || *ptr == NULL) {
        return;
    }

    user = *ptr;
    rec = find_entry(user);
    if (rec == NULL) {
        fprintf(stderr, "[SafeMem] ERROR: free of untracked pointer %p at %s:%d\n", user, file, line);
        exit(1);
    }
    if (rec->is_freed) {
        fprintf(stderr, "[SafeMem] ERROR: Double-free detected %p at %s:%d (originally %s:%d)\n",
                user, file, line, rec->file, rec->line);
        exit(1);
    }
    if (!check_canary(user, rec->size)) {
        fprintf(stderr, "[SafeMem] ERROR: Canary corruption detected before free %p at %s:%d\n", user, file, line);
        exit(1);
    }

    base = (char*)user - SAFE_CANARY_SIZE;
    memset(base, SAFE_FILL_FREED, rec->size + 2 * SAFE_CANARY_SIZE);
    rec->is_freed = 1;
    g_total_freed += rec->size;
    free(base);
    *ptr = NULL;
}

/* Validate a user pointer: returns 1 only for a live tracked block with
 * intact canaries. NULL returns 0 silently; untracked, freed, or corrupted
 * pointers print a diagnostic to stderr and return 0. */
int safe_mem_check(const void* ptr, const char* file, int line) {
    SafeAllocEntry* rec;
    int ok = 0;

    if (ptr == NULL) {
        return ok;
    }

    rec = find_entry((void*)ptr);
    if (rec == NULL) {
        fprintf(stderr, "[SafeMem] WARN: Check of untracked pointer %p at %s:%d\n", ptr, file, line);
    } else if (rec->is_freed) {
        fprintf(stderr, "[SafeMem] ERROR: Use-after-free detected %p at %s:%d\n", ptr, file, line);
    } else if (!check_canary(ptr, rec->size)) {
        fprintf(stderr, "[SafeMem] ERROR: Canary corruption %p at %s:%d\n", ptr, file, line);
    } else {
        ok = 1;
    }
    return ok;
}

/* Print cumulative, freed, live, and peak host byte counts to stdout. */
void safe_mem_stats(void) {
    if (g_initialized) {
        printf("[SafeMem] Stats: total=%zu, freed=%zu, current=%zu, peak=%zu\n",
               g_total_allocated, g_total_freed,
               g_total_allocated - g_total_freed, g_peak_allocated);
    } else {
        printf("[SafeMem] Not initialized.\n");
    }
}

1.3 safemem_device.cu

#include "safemem.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* ============================================================
 * SafeMem Device Implementation
 * ============================================================ */

/* Tracking list for cudaMalloc and cudaMallocHost allocations; new entries
 * are pushed at the front and freed entries are kept (is_freed set). */
static SafeDeviceAllocEntry* g_device_alloc_head = NULL;
/* Cumulative bytes allocated / freed through the wrappers. */
static size_t g_device_total_allocated = 0;
static size_t g_device_total_freed = 0;
/* High-water mark of (allocated - freed). */
static size_t g_device_peak_allocated = 0;
/* Non-zero between safe_device_mem_init() and safe_device_mem_shutdown(). */
static int g_device_initialized = 0;

/* Put the device tracker into a clean state; no-op if already initialized. */
void safe_device_mem_init(void) {
    if (!g_device_initialized) {
        g_device_peak_allocated = 0;
        g_device_total_freed = 0;
        g_device_total_allocated = 0;
        g_device_alloc_head = NULL;
        g_device_initialized = 1;
    }
}

/* Release a leaked tracked allocation with the matching CUDA API.
 * The tracking list mixes cudaMalloc (device) and cudaMallocHost (pinned
 * host) allocations; pinned host memory must go through cudaFreeHost, not
 * cudaFree. Anything not identified as host memory falls back to cudaFree. */
static void release_leaked_allocation(void* ptr) {
    cudaPointerAttributes attr;

    if (ptr == NULL) return;

    if (cudaPointerGetAttributes(&attr, ptr) == cudaSuccess &&
        attr.type == cudaMemoryTypeHost) {
        cudaFreeHost(ptr);
        return;
    }
    /* Clear a (non-sticky) error a failed attribute query may have left, so
     * it does not surface from an unrelated later CUDA call. */
    (void)cudaGetLastError();
    cudaFree(ptr);
}

/* Report every still-live tracked allocation as a leak (stderr), release it,
 * free all tracking nodes, and mark the tracker uninitialized. */
void safe_device_mem_shutdown(void) {
    if (!g_device_initialized) return;

    int leak_count = 0;
    size_t leak_bytes = 0;
    SafeDeviceAllocEntry* entry = g_device_alloc_head;

    while (entry) {
        if (!entry->is_freed) {
            leak_count++;
            leak_bytes += entry->size;
            fprintf(stderr, "[SafeMem Device] LEAK: %p, size=%zu, %s:%d\n",
                    entry->ptr, entry->size, entry->file, entry->line);
            release_leaked_allocation(entry->ptr);
        }
        SafeDeviceAllocEntry* next = entry->next;
        free(entry);
        entry = next;
    }

    if (leak_count > 0) {
        fprintf(stderr, "[SafeMem Device] TOTAL LEAKS: %d allocations, %zu bytes\n",
                leak_count, leak_bytes);
    } else {
        fprintf(stderr, "[SafeMem Device] No device memory leaks detected.\n");
    }

    g_device_alloc_head = NULL;
    g_device_initialized = 0;
}

/* Push a new live allocation onto the device tracking list and update the
 * allocated/peak counters; exits if the tracking node cannot be allocated. */
static void add_device_entry(void* ptr, size_t size, const char* file, int line) {
    size_t live;
    SafeDeviceAllocEntry* node = (SafeDeviceAllocEntry*)malloc(sizeof *node);

    if (node == NULL) {
        fprintf(stderr, "[SafeMem Device] FATAL: Failed to allocate tracking entry\n");
        exit(1);
    }

    node->next = g_device_alloc_head;
    node->is_freed = 0;
    node->line = line;
    node->file = file;
    node->size = size;
    node->ptr = ptr;
    g_device_alloc_head = node;

    g_device_total_allocated += size;
    live = g_device_total_allocated - g_device_total_freed;
    g_device_peak_allocated = (live > g_device_peak_allocated) ? live : g_device_peak_allocated;
}

/* Return the newest tracking entry whose pointer equals ptr, or NULL. */
static SafeDeviceAllocEntry* find_device_entry(void* ptr) {
    SafeDeviceAllocEntry* node;
    for (node = g_device_alloc_head; node != NULL && node->ptr != ptr; node = node->next) {
        /* keep scanning */
    }
    return node;
}

/* Tracked cudaMalloc. A NULL out-pointer yields cudaErrorInvalidValue;
 * size == 0 warns, stores NULL, and succeeds without tracking. On CUDA
 * failure the error is printed and returned, and nothing is tracked. */
cudaError_t safe_cudaMalloc_impl(void** devPtr, size_t size, const char* file, int line) {
    cudaError_t status;

    if (!g_device_initialized) {
        safe_device_mem_init();
    }

    if (devPtr == NULL) {
        fprintf(stderr, "[SafeMem Device] ERROR: cudaMalloc NULL devPtr at %s:%d\n", file, line);
        return cudaErrorInvalidValue;
    }

    if (size == 0) {
        fprintf(stderr, "[SafeMem Device] WARN: cudaMalloc(0) at %s:%d\n", file, line);
        *devPtr = NULL;
        return cudaSuccess;
    }

    status = cudaMalloc(devPtr, size);
    if (status == cudaSuccess) {
        add_device_entry(*devPtr, size, file, line);
    } else {
        fprintf(stderr, "[SafeMem Device] ERROR: cudaMalloc(%zu) failed at %s:%d: %s\n",
                size, file, line, cudaGetErrorString(status));
    }
    return status;
}

/* Tracked cudaMallocHost (pinned host memory).
 * A NULL out-pointer yields cudaErrorInvalidValue. size == 0 now mirrors
 * safe_cudaMalloc_impl: warn, store NULL, succeed, and do not track. The
 * earlier code tracked whatever zero-byte result came back; when that is
 * NULL, safe_cudaFreeHost_impl() returns before looking it up, so the entry
 * could never be released and would be reported as a leak at shutdown. */
cudaError_t safe_cudaMallocHost_impl(void** ptr, size_t size, const char* file, int line) {
    if (!g_device_initialized) safe_device_mem_init();
    if (!ptr) {
        fprintf(stderr, "[SafeMem Device] ERROR: cudaMallocHost NULL ptr at %s:%d\n", file, line);
        return cudaErrorInvalidValue;
    }
    if (size == 0) {
        fprintf(stderr, "[SafeMem Device] WARN: cudaMallocHost(0) at %s:%d\n", file, line);
        *ptr = NULL;
        return cudaSuccess;
    }

    cudaError_t err = cudaMallocHost(ptr, size);
    if (err != cudaSuccess) {
        fprintf(stderr, "[SafeMem Device] ERROR: cudaMallocHost(%zu) failed at %s:%d: %s\n",
                size, file, line, cudaGetErrorString(err));
        return err;
    }

    add_device_entry(*ptr, size, file, line);
    return cudaSuccess;
}

/* Tracked cudaFree. NULL succeeds as a no-op. An untracked pointer is freed
 * anyway after a warning. A second free of a tracked pointer is reported and
 * rejected with cudaErrorInvalidDevicePointer without calling cudaFree. */
cudaError_t safe_cudaFree_impl(void* devPtr, const char* file, int line) {
    SafeDeviceAllocEntry* rec;

    if (!g_device_initialized) {
        safe_device_mem_init();
    }
    if (devPtr == NULL) {
        return cudaSuccess;
    }

    rec = find_device_entry(devPtr);
    if (rec == NULL) {
        fprintf(stderr, "[SafeMem Device] WARN: cudaFree of untracked pointer %p at %s:%d\n",
                devPtr, file, line);
    } else if (rec->is_freed) {
        fprintf(stderr, "[SafeMem Device] ERROR: Double cudaFree detected %p at %s:%d (originally %s:%d)\n",
                devPtr, file, line, rec->file, rec->line);
        return cudaErrorInvalidDevicePointer;
    } else {
        rec->is_freed = 1;
        g_device_total_freed += rec->size;
    }
    return cudaFree(devPtr);
}

/* Tracked cudaFreeHost. NULL succeeds as a no-op. An untracked pointer is
 * freed anyway after a warning. A second free of a tracked pointer is
 * reported and rejected with cudaErrorInvalidValue. */
cudaError_t safe_cudaFreeHost_impl(void* ptr, const char* file, int line) {
    SafeDeviceAllocEntry* rec;

    if (!g_device_initialized) {
        safe_device_mem_init();
    }
    if (ptr == NULL) {
        return cudaSuccess;
    }

    rec = find_device_entry(ptr);
    if (rec == NULL) {
        fprintf(stderr, "[SafeMem Device] WARN: cudaFreeHost of untracked pointer %p at %s:%d\n",
                ptr, file, line);
    } else if (rec->is_freed) {
        fprintf(stderr, "[SafeMem Device] ERROR: Double cudaFreeHost detected %p at %s:%d\n",
                ptr, file, line);
        return cudaErrorInvalidValue;
    } else {
        rec->is_freed = 1;
        g_device_total_freed += rec->size;
    }
    return cudaFreeHost(ptr);
}

/* Checked cudaMemcpy: NULL endpoints are rejected with cudaErrorInvalidValue,
 * a zero-byte copy succeeds without calling CUDA, and CUDA failures are
 * printed and returned. */
cudaError_t safe_cudaMemcpy_impl(void* dst, const void* src, size_t count,
                                  cudaMemcpyKind kind, const char* file, int line) {
    cudaError_t status = cudaSuccess;

    if (dst == NULL || src == NULL) {
        fprintf(stderr, "[SafeMem Device] ERROR: cudaMemcpy NULL ptr at %s:%d\n", file, line);
        return cudaErrorInvalidValue;
    }

    if (count > 0) {
        status = cudaMemcpy(dst, src, count, kind);
        if (status != cudaSuccess) {
            fprintf(stderr, "[SafeMem Device] ERROR: cudaMemcpy(%zu) failed at %s:%d: %s\n",
                    count, file, line, cudaGetErrorString(status));
        }
    }
    return status;
}

/* Checked cudaMemset: a NULL pointer is rejected with cudaErrorInvalidValue;
 * CUDA failures are printed and returned. */
cudaError_t safe_cudaMemset_impl(void* devPtr, int value, size_t count,
                                  const char* file, int line) {
    cudaError_t status;

    if (devPtr != NULL) {
        status = cudaMemset(devPtr, value, count);
        if (status != cudaSuccess) {
            fprintf(stderr, "[SafeMem Device] ERROR: cudaMemset failed at %s:%d: %s\n",
                    file, line, cudaGetErrorString(status));
        }
        return status;
    }

    fprintf(stderr, "[SafeMem Device] ERROR: cudaMemset NULL ptr at %s:%d\n", file, line);
    return cudaErrorInvalidValue;
}

/* Print cumulative, freed, live, and peak tracked byte counts to stdout. */
void safe_device_mem_stats(void) {
    if (g_device_initialized) {
        printf("[SafeMem Device] Stats: total=%zu, freed=%zu, current=%zu, peak=%zu\n",
               g_device_total_allocated, g_device_total_freed,
               g_device_total_allocated - g_device_total_freed, g_device_peak_allocated);
    } else {
        printf("[SafeMem Device] Not initialized.\n");
    }
}

1.4 safemem_pool.cu

#include "safemem.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* ============================================================
 * SafeMem Device Memory Pool
 * Bump allocator with simple free list
 * ============================================================ */

/* Free-list node: a previously handed-out region of the pool that may be
 * reused by dev_pool_malloc. In the code shown here nodes are only consumed
 * (try_free_list) or discarded (dev_pool_reset); nothing visible inserts
 * them — NOTE(review): confirm whether a pool-free routine exists elsewhere. */
typedef struct PoolFreeNode {
    void* ptr;                  /* start of the reusable region (device memory) */
    size_t size;                /* usable bytes in the region */
    struct PoolFreeNode* next;  /* next node (host-side list) */
} PoolFreeNode;

/* Allocate a pool struct plus one contiguous device region of initial_size
 * bytes (64 MB when initial_size is 0). Returns NULL, after printing the
 * reason, if either allocation fails. */
DeviceMemPool* dev_pool_create(size_t initial_size) {
    const size_t default_bytes = (size_t)64 * 1024 * 1024; /* 64MB default */
    const size_t bytes = (initial_size != 0) ? initial_size : default_bytes;
    cudaError_t status;
    DeviceMemPool* pool = (DeviceMemPool*)malloc(sizeof *pool);

    if (pool == NULL) {
        fprintf(stderr, "[DevPool] FATAL: Failed to allocate pool struct\n");
        return NULL;
    }

    status = cudaMalloc(&pool->pool_base, bytes);
    if (status != cudaSuccess) {
        fprintf(stderr, "[DevPool] FATAL: cudaMalloc(%zu) failed: %s\n",
                bytes, cudaGetErrorString(status));
        free(pool);
        return NULL;
    }

    /* Start empty: bump pointer at the base, no reusable blocks. */
    pool->bump_ptr = pool->pool_base;
    pool->free_list = NULL;
    pool->pool_size = bytes;
    pool->used = 0;
    pool->peak_used = 0;
    pool->num_allocs = 0;

    printf("[DevPool] Created pool: %zu MB at %p\n", bytes / (1024*1024), pool->pool_base);
    return pool;
}

/* First-fit search of the pool's free list: unlink and return the region of
 * the first node with at least `size` bytes (counting it as an allocation),
 * or NULL if no node is large enough. Leftover bytes are not split off. */
static void* try_free_list(DeviceMemPool* pool, size_t size) {
    PoolFreeNode* prev = NULL;
    PoolFreeNode* node = pool->free_list;
    void* found;

    while (node != NULL && node->size < size) {
        prev = node;
        node = node->next;
    }
    if (node == NULL) {
        return NULL;
    }

    if (prev == NULL) {
        pool->free_list = node->next;
    } else {
        prev->next = node->next;
    }
    found = node->ptr;
    free(node);
    pool->num_allocs++;
    return found;
}

/* Carve `size` bytes (rounded up to a multiple of 256) out of the pool.
 * The free list is tried first (first fit), then the bump pointer.
 * Returns NULL for a NULL pool, size == 0, a size too large to round up, or
 * when the remaining contiguous space is insufficient.
 * No locking is performed; concurrent callers must synchronize externally. */
void* dev_pool_malloc(DeviceMemPool* pool, size_t size) {
    const size_t align = 256; /* Align to 256 bytes for CUDA */

    if (!pool || size == 0) return NULL;

    /* Without this guard, size + 255 wraps for sizes near SIZE_MAX and the
     * aligned size becomes 0 or tiny, so a huge request would "succeed". */
    if (size > (size_t)-1 - (align - 1)) {
        fprintf(stderr, "[DevPool] ERROR: Requested size %zu too large\n", size);
        return NULL;
    }
    size_t aligned_size = (size + (align - 1)) & ~(align - 1);

    /* Try free list first */
    void* ptr = try_free_list(pool, aligned_size);
    if (ptr) return ptr;

    /* Bump allocate. Compare against the remaining space rather than
     * offset + aligned_size, which could itself overflow. */
    char* bump = (char*)pool->bump_ptr;
    size_t offset = (size_t)(bump - (char*)pool->pool_base);
    size_t remaining = pool->pool_size - offset;
    if (aligned_size > remaining) {
        fprintf(stderr, "[DevPool] ERROR: Out of memory (requested %zu, available %zu)\n",
                aligned_size, remaining);
        return NULL;
    }

    ptr = bump;
    pool->bump_ptr = bump + aligned_size;
    pool->used += aligned_size;
    pool->num_allocs++;

    if (pool->used > pool->peak_used) pool->peak_used = pool->used;

    return ptr;
}

/* Return every allocation to the pool at once: drop all free-list nodes,
 * rewind the bump pointer, and zero the usage counters (peak_used is kept).
 * The device region itself is not released. */
void dev_pool_reset(DeviceMemPool* pool) {
    if (pool == NULL) {
        return;
    }

    while (pool->free_list != NULL) {
        PoolFreeNode* victim = pool->free_list;
        pool->free_list = victim->next;
        free(victim);
    }

    pool->num_allocs = 0;
    pool->used = 0;
    pool->bump_ptr = pool->pool_base;

    printf("[DevPool] Pool reset\n");
}

/* Tear down a pool: reset it, release the device region, print the peak
 * usage, and free the pool struct. NULL is ignored. */
void dev_pool_destroy(DeviceMemPool* pool) {
    const size_t mib = (size_t)1024 * 1024;

    if (pool == NULL) {
        return;
    }

    dev_pool_reset(pool);
    if (pool->pool_base != NULL) {
        cudaFree(pool->pool_base);
    }

    printf("[DevPool] Destroyed pool (peak usage: %zu MB / %zu MB)\n",
           pool->peak_used / mib, pool->pool_size / mib);
    free(pool);
}

/* Copy total size, current usage, and peak usage into whichever output
 * pointers are non-NULL. Outputs are left untouched when pool is NULL. */
void dev_pool_stats(DeviceMemPool* pool, size_t* total, size_t* used, size_t* peak) {
    if (pool != NULL) {
        if (total != NULL) *total = pool->pool_size;
        if (used != NULL) *used = pool->used;
        if (peak != NULL) *peak = pool->peak_used;
    }
}

2. SIFT GPU 加速模块

2.1 sift_detect.h

#ifndef SIFT_DETECT_H
#define SIFT_DETECT_H

// Included before (outside) the extern "C" block: when compiled as C++,
// cuda_runtime.h contains templates/overloads that cannot be given C linkage.
#include <cuda_runtime.h>

#ifdef __cplusplus
extern "C" {
#endif

// Device-side feature structure (one keypoint record).
// Field notes reflect how refine_features_kernel in sift_detect.cu fills them.
typedef struct {
    float x, y;                  // refined position scaled by 2^octv
    float scl;                   // scale; left 0 by refinement (filled later — TODO confirm)
    float ori;                   // orientation; left 0 by refinement
    float descr[128];            // descriptor storage
    int r, c;                    // integer row/column within the octave image
    int octv, intvl;             // octave index and DoG interval index
    float subintvl;              // sub-interval offset from refinement
    float scl_octv;              // scale within the octave; read by calc_ori_kernel
    int img_width, img_height;   // dimensions of the octave image
    int d;                       // set to 0 by refinement; presumably descriptor length — verify
    float img_pt_x, img_pt_y;    // copy of x, y
} FeatureDevice;

// GPU extremum detection + sub-pixel interpolation + edge filtering
int sift_detect_extrema_gpu(
    const float* const* d_dog_pyr,
    const int* d_widths,
    const int* d_heights,
    int octvs, int intvls,
    float contr_thr, float curv_thr,
    FeatureDevice** d_out_features,
    int* h_out_count
);

// GPU orientation assignment
int sift_calc_oris_gpu(
    const float* const* d_gauss_pyr,
    const int* d_widths,
    const int* d_heights,
    FeatureDevice* d_features,
    int num_features,
    FeatureDevice** d_out_features,
    int* h_out_count
);

// GPU descriptor generation
int sift_compute_descriptors_gpu(
    const float* const* d_gauss_pyr,
    const int* d_widths,
    const int* d_heights,
    FeatureDevice* d_features,
    int num_features,
    int d, int n
);

#ifdef __cplusplus
}
#endif

#endif

2.2 sift_detect.cu

#include "sift_detect.h"
#include "safemem/safemem.h"
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>
#include <stdio.h>
#include <stdlib.h>

// Float fallback for pi.
// NOTE(review): if <math.h> already defines M_PI (commonly as a double), this
// fallback is skipped and expressions such as `ori + M_PI` in device code are
// evaluated in double precision. A dedicated float constant would avoid that.
#ifndef M_PI
#define M_PI 3.14159265358979323846f
#endif

// Pixels within this distance of the octave-image edge are not considered
// as extrema (detect_extrema_kernel, refine_features_kernel).
#define SIFT_IMG_BORDER 5
// Maximum Newton refinement steps in refine_features_kernel.
#define SIFT_MAX_INTERP_STEPS 5
// Orientation histogram parameters used by calc_ori_kernel.
#define SIFT_ORI_HIST_BINS 36
#define SIFT_ORI_SIG_FCTR 1.5f
#define SIFT_ORI_RADIUS (3.0f * SIFT_ORI_SIG_FCTR)
#define SIFT_ORI_SMOOTH_PASSES 2
#define SIFT_ORI_PEAK_RATIO 0.8f
// Descriptor parameters; presumably consumed by descriptor code not shown in
// this excerpt — confirm.
#define SIFT_DESCR_SCL_FCTR 3.0f
#define SIFT_DESCR_MAG_THR 0.2f
#define SIFT_INT_DESCR_FCTR 512.0f
#define SIFT_DESCR_WIDTH 4
#define SIFT_DESCR_HIST_BINS 8

// ===== GPU extremum detection =====

// Flag 3x3x3 scale-space extrema of DoG interval `intvl` in octave `octv`.
// Launch: 2D grid; x maps to columns and y to rows of the octave image.
// Only interior pixels (outside SIFT_IMG_BORDER) whose |DoG| exceeds
// prelim_contr_thr get a mask write (1 = extremum, 0 = not); every other mask
// byte is left untouched, so the caller presumably clears the mask first —
// TODO confirm.
// NOTE(review): the DoG layer index uses a hard-coded 5 images per octave,
// which corresponds to intvls == 3.
__global__ void detect_extrema_kernel(
    const float* const* dog_pyr,
    const int* widths,
    const int* heights,
    int octv,
    int intvl,
    float prelim_contr_thr,
    unsigned char* extremum_mask
) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int width = widths[octv];
    const int height = heights[octv];

    const bool interior = col >= SIFT_IMG_BORDER && col < width - SIFT_IMG_BORDER &&
                          row >= SIFT_IMG_BORDER && row < height - SIFT_IMG_BORDER;
    if (!interior) return;

    const int pix = row * width + col;
    const int layer = octv * 5 + intvl;
    const float center = dog_pyr[layer][pix];
    if (fabsf(center) <= prelim_contr_thr) return;

    bool is_max = true;
    bool is_min = true;

    // Compare against the 26 neighbours; stop as soon as neither a maximum
    // nor a minimum is still possible.
    #pragma unroll
    for (int ds = -1; ds <= 1 && (is_max || is_min); ++ds) {
        const float* layer_img = dog_pyr[layer + ds];
        for (int dy = -1; dy <= 1 && (is_max || is_min); ++dy) {
            const float* row_ptr = layer_img + (row + dy) * width + col;
            for (int dx = -1; dx <= 1 && (is_max || is_min); ++dx) {
                if (ds == 0 && dy == 0 && dx == 0) continue;
                const float nb = row_ptr[dx];
                if (center < nb) is_max = false;
                if (center > nb) is_min = false;
            }
        }
    }

    extremum_mask[pix] = (unsigned char)((is_max || is_min) ? 1 : 0);
}

// ===== GPU sub-pixel interpolation + edge filtering =====

// Read pixel (r, c) of a row-major float image with row stride w.
// No bounds checking is performed.
__device__ float d_gray_image_get(const float* data, int w, int r, int c) {
    const float* row_ptr = data + r * w;
    return row_ptr[c];
}

// Central-difference gradient at (r, c) of the middle DoG layer dog1:
// dI[0] = d/dcolumn, dI[1] = d/drow, dI[2] = d/dscale (dog2 vs dog0).
__device__ void d_deriv_3D(const float* dog0, const float* dog1, const float* dog2,
                           int w, int r, int c, float dI[3]) {
    const int at = r * w + c;
    const float gx = dog1[at + 1] - dog1[at - 1];
    const float gy = dog1[at + w] - dog1[at - w];
    const float gs = dog2[at] - dog0[at];
    dI[0] = gx * 0.5f;
    dI[1] = gy * 0.5f;
    dI[2] = gs * 0.5f;
}

// Finite-difference 3x3 Hessian of the DoG scale space at (r, c) of dog1.
// Axis order is (column, row, scale), matching d_deriv_3D; H is symmetric.
__device__ void d_hessian_3D(const float* dog0, const float* dog1, const float* dog2,
                             int w, int r, int c, float H[3][3]) {
    const int at = r * w + c;
    const float v = dog1[at];

    const float dxx = dog1[at + 1] + dog1[at - 1] - 2.0f * v;
    const float dyy = dog1[at + w] + dog1[at - w] - 2.0f * v;
    const float dss = dog2[at] + dog0[at] - 2.0f * v;
    const float dxy = (dog1[at + w + 1] - dog1[at + w - 1]
                     - dog1[at - w + 1] + dog1[at - w - 1]) * 0.25f;
    const float dxs = (dog2[at + 1] - dog2[at - 1]
                     - dog0[at + 1] + dog0[at - 1]) * 0.25f;
    const float dys = (dog2[at + w] - dog2[at - w]
                     - dog0[at + w] + dog0[at - w]) * 0.25f;

    H[0][0] = dxx;
    H[1][1] = dyy;
    H[2][2] = dss;
    H[0][1] = H[1][0] = dxy;
    H[0][2] = H[2][0] = dxs;
    H[1][2] = H[2][1] = dys;
}

// Invert a 3x3 matrix as adjugate / determinant (cofactor expansion along
// the first row for the determinant).
// Returns 0 and leaves H_inv unwritten when |det| < 1e-12f (treated as
// singular); otherwise fills H_inv and returns 1.
// NOTE(review): the singularity threshold is absolute, not scaled to the
// magnitude of H's entries.
__device__ int d_invert3x3(float H[3][3], float H_inv[3][3]) {
    float det = H[0][0]*(H[1][1]*H[2][2]-H[1][2]*H[2][1])
              - H[0][1]*(H[1][0]*H[2][2]-H[1][2]*H[2][0])
              + H[0][2]*(H[1][0]*H[2][1]-H[1][1]*H[2][0]);
    if (fabsf(det) < 1e-12f) return 0;
    float inv_det = 1.0f / det;

    // H_inv[i][j] = cofactor(j, i) / det
    H_inv[0][0] = (H[1][1]*H[2][2]-H[1][2]*H[2][1])*inv_det;
    H_inv[0][1] = (H[0][2]*H[2][1]-H[0][1]*H[2][2])*inv_det;
    H_inv[0][2] = (H[0][1]*H[1][2]-H[0][2]*H[1][1])*inv_det;
    H_inv[1][0] = (H[1][2]*H[2][0]-H[1][0]*H[2][2])*inv_det;
    H_inv[1][1] = (H[0][0]*H[2][2]-H[0][2]*H[2][0])*inv_det;
    H_inv[1][2] = (H[0][2]*H[1][0]-H[0][0]*H[1][2])*inv_det;
    H_inv[2][0] = (H[1][0]*H[2][1]-H[1][1]*H[2][0])*inv_det;
    H_inv[2][1] = (H[0][1]*H[2][0]-H[0][0]*H[2][1])*inv_det;
    H_inv[2][2] = (H[0][0]*H[1][1]-H[0][1]*H[1][0])*inv_det;
    return 1;
}

// Edge-response test on the 2x2 spatial Hessian at (r, c):
// returns 1 (reject) when the point touches the image edge, when the
// determinant is non-positive, or when trace^2 / det is not below
// (curv_thr + 1)^2 / curv_thr; returns 0 (keep) otherwise.
__device__ int d_is_too_edge_like(const float* dog_img, int w, int h, int r, int c, float curv_thr) {
    const bool on_border = (c <= 0) || (r <= 0) || (c >= w - 1) || (r >= h - 1);
    if (on_border) return 1;

    const int at = r * w + c;
    const float center = dog_img[at];
    const float dxx = dog_img[at + 1] + dog_img[at - 1] - 2.0f * center;
    const float dyy = dog_img[at + w] + dog_img[at - w] - 2.0f * center;
    const float dxy = (dog_img[at + w + 1] - dog_img[at + w - 1]
                     - dog_img[at - w + 1] + dog_img[at - w - 1]) * 0.25f;

    const float trace = dxx + dyy;
    const float det = dxx * dyy - dxy * dxy;
    if (det <= 0) return 1;

    const float limit = (curv_thr + 1.0f) * (curv_thr + 1.0f) / curv_thr;
    return !(trace * trace / det < limit);
}

// Refine each candidate flagged in extremum_mask to sub-pixel / sub-interval
// precision, then reject low-contrast and edge-like points.
//
// Launch: 1D grid, one thread per pixel of octave `octv`
// (idx = blockIdx.x * blockDim.x + threadIdx.x, checked against w * h).
// dog_pyr holds (intvls + 2) DoG images per octave and is indexed as
// dog_pyr[octv * (intvls + 2) + interval].
// Accepted points are appended to out_features through an atomic counter.
// Candidates beyond max_features are dropped and their increment rolled
// back, so *out_count ends at most at max_features if it starts at or below it.
__global__ void refine_features_kernel(
    const float* const* dog_pyr,
    const int* widths,
    const int* heights,
    int octv,
    int intvl,
    int intvls,
    float contr_thr,
    float curv_thr,
    const unsigned char* extremum_mask,
    FeatureDevice* out_features,
    int* out_count,
    int max_features
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int w = widths[octv];
    int h = heights[octv];
    int total = w * h;

    if (idx >= total || !extremum_mask[idx]) return;

    // This octave's DoG layers. The stride now follows `intvls` instead of a
    // hard-coded 5, which was only correct for intvls == 3.
    const float* const* octv_dog = dog_pyr + octv * (intvls + 2);

    int c = idx % w;
    int r = idx / w;

    // Offset of the extremum relative to (curr_c, curr_r, curr_intvl).
    float xi = 0.0f, xr = 0.0f, xc = 0.0f;
    int curr_r = r, curr_c = c, curr_intvl = intvl;
    int i = 0;

    const float* dog0 = octv_dog[curr_intvl - 1];
    const float* dog1 = octv_dog[curr_intvl];
    const float* dog2 = octv_dog[curr_intvl + 1];

    while (i < SIFT_MAX_INTERP_STEPS) {
        float dI[3];
        d_deriv_3D(dog0, dog1, dog2, w, curr_r, curr_c, dI);
        float H[3][3];
        d_hessian_3D(dog0, dog1, dog2, w, curr_r, curr_c, H);
        float H_inv[3][3];
        if (!d_invert3x3(H, H_inv)) {
            // Singular Hessian: no offset can be computed at this location.
            // Any offset still held from the previous iteration belongs to the
            // previous location (it was already applied by rounding), so
            // discard it instead of adding it to the current location below.
            xc = xr = xi = 0.0f;
            break;
        }

        // Newton step: x = -H^-1 * dI
        float x[3] = {0.0f, 0.0f, 0.0f};
        for (int j = 0; j < 3; j++)
            for (int k = 0; k < 3; k++)
                x[j] -= H_inv[j][k] * dI[k];

        xc = x[0]; xr = x[1]; xi = x[2];

        // Converged: the extremum lies within half a sample of this point.
        if (fabsf(xi) < 0.5f && fabsf(xr) < 0.5f && fabsf(xc) < 0.5f) break;

        curr_c += (int)roundf(xc);
        curr_r += (int)roundf(xr);
        curr_intvl += (int)roundf(xi);

        // Drop candidates that wander out of the valid interval range or
        // into the border.
        if (curr_intvl < 1 || curr_intvl > intvls ||
            curr_c < SIFT_IMG_BORDER || curr_r < SIFT_IMG_BORDER ||
            curr_c >= w - SIFT_IMG_BORDER || curr_r >= h - SIFT_IMG_BORDER) {
            return;
        }

        dog0 = octv_dog[curr_intvl - 1];
        dog1 = octv_dog[curr_intvl];
        dog2 = octv_dog[curr_intvl + 1];
        i++;
    }

    // Did not converge within the step budget.
    if (i >= SIFT_MAX_INTERP_STEPS) return;

    // Interpolated DoG value at the refined point: D + 0.5 * (dI . x)
    float dI[3];
    d_deriv_3D(dog0, dog1, dog2, w, curr_r, curr_c, dI);
    float t = dI[0] * xc + dI[1] * xr + dI[2] * xi;
    float contr = d_gray_image_get(dog1, w, curr_r, curr_c) + t * 0.5f;
    if (fabsf(contr) < contr_thr / intvls) return;

    if (d_is_too_edge_like(dog1, w, h, curr_r, curr_c, curv_thr)) return;

    int pos = atomicAdd(out_count, 1);
    if (pos >= max_features) {
        atomicSub(out_count, 1);
        return;
    }

    const float octv_scale = powf(2.0f, octv);
    FeatureDevice feat;
    feat.x = (curr_c + xc) * octv_scale;
    feat.y = (curr_r + xr) * octv_scale;
    feat.img_pt_x = feat.x;
    feat.img_pt_y = feat.y;
    feat.r = curr_r;
    feat.c = curr_c;
    feat.octv = octv;
    feat.intvl = curr_intvl;
    feat.subintvl = xi;
    // Scale, orientation and descriptor fields are left for later stages
    // (presumably filled on the host or by later kernels — TODO confirm).
    feat.scl_octv = 0;
    feat.scl = 0;
    feat.ori = 0;
    feat.d = 0;
    feat.img_width = w;
    feat.img_height = h;
    out_features[pos] = feat;
}

// ===== GPU 方向分配 =====

/*
 * Orientation assignment: one block per input keypoint.
 *
 * Launch: grid.x >= num_features (excess blocks exit immediately), any
 * blockDim.x; threads stride over the (2*rad+1)^2 sampling window. Uses a
 * static shared histogram of SIFT_ORI_HIST_BINS floats.
 *
 * After two-pass circular smoothing, every local histogram peak that is
 * >= SIFT_ORI_PEAK_RATIO * max produces one copy of the feature (with `ori`
 * set from the parabolic-interpolated peak, in [-pi, pi)) appended to
 * out_features through the atomic counter out_num_features. When the output
 * buffer is full the counter increment is rolled back, so after the kernel
 * *out_num_features never exceeds max_out_features.
 */
__global__ void calc_ori_kernel(
    const float* const* gauss_pyr,
    const int* widths,
    const int* heights,
    FeatureDevice* features,
    int num_features,
    int* out_num_features,
    FeatureDevice* out_features,
    int max_out_features
) {
    int feat_idx = blockIdx.x;
    // Uniform for the whole block, so returning before __syncthreads() is safe.
    if (feat_idx >= num_features) return;

    FeatureDevice feat = features[feat_idx];
    int w = widths[feat.octv];
    int h = heights[feat.octv];
    int r = feat.r;
    int c = feat.c;
    float scl = feat.scl_octv;

    int rad = (int)roundf(SIFT_ORI_RADIUS * scl);
    float sigma = SIFT_ORI_SIG_FCTR * scl;
    float exp_denom = 2.0f * sigma * sigma;
    float PI2 = 2.0f * M_PI;

    // Strided zeroing: correct for any blockDim.x, including < SIFT_ORI_HIST_BINS.
    __shared__ float s_hist[SIFT_ORI_HIST_BINS];
    for (int j = threadIdx.x; j < SIFT_ORI_HIST_BINS; j += blockDim.x) s_hist[j] = 0.0f;
    __syncthreads();

    int area = (2 * rad + 1) * (2 * rad + 1);
    // Assumes 6 Gaussian layers per octave (intvls + 3 with intvls == 3).
    const float* img = gauss_pyr[feat.octv * 6 + feat.intvl];

    for (int i = threadIdx.x; i < area; i += blockDim.x) {
        int dy = i / (2 * rad + 1) - rad;
        int dx = i % (2 * rad + 1) - rad;
        int pr = r + dy;
        int pc = c + dx;

        // Central differences need one valid neighbour on every side.
        if (pr <= 0 || pr >= h - 1 || pc <= 0 || pc >= w - 1) continue;

        float dx_val = img[pr * w + pc + 1] - img[pr * w + pc - 1];
        float dy_val = img[(pr - 1) * w + pc] - img[(pr + 1) * w + pc];
        float mag = sqrtf(dx_val * dx_val + dy_val * dy_val);
        float ori = atan2f(dy_val, dx_val);

        float weight = expf(-(dx * dx + dy * dy) / exp_denom);
        int bin = (int)roundf(SIFT_ORI_HIST_BINS * (ori + M_PI) / PI2);
        bin = (bin < SIFT_ORI_HIST_BINS) ? bin : 0;  // ori == +pi wraps to bin 0

        atomicAdd(&s_hist[bin], weight * mag);
    }
    __syncthreads();

    // The remaining work is tiny and sequential; thread 0 does it alone.
    if (threadIdx.x != 0) return;

    // Circular [0.25, 0.5, 0.25] smoothing, SIFT_ORI_SMOOTH_PASSES times.
    for (int s = 0; s < SIFT_ORI_SMOOTH_PASSES; s++) {
        float h0 = s_hist[0];
        float prev = s_hist[SIFT_ORI_HIST_BINS - 1];
        for (int j = 0; j < SIFT_ORI_HIST_BINS; j++) {
            float tmp = s_hist[j];
            float next = (j + 1 == SIFT_ORI_HIST_BINS) ? h0 : s_hist[j + 1];
            s_hist[j] = 0.25f * prev + 0.5f * s_hist[j] + 0.25f * next;
            prev = tmp;
        }
    }

    float omax = 0;
    for (int j = 0; j < SIFT_ORI_HIST_BINS; j++)
        if (s_hist[j] > omax) omax = s_hist[j];

    float mag_thr = omax * SIFT_ORI_PEAK_RATIO;

    for (int j = 0; j < SIFT_ORI_HIST_BINS; j++) {
        int l = (j == 0) ? SIFT_ORI_HIST_BINS - 1 : j - 1;
        int r_idx = (j + 1) % SIFT_ORI_HIST_BINS;
        if (s_hist[j] > s_hist[l] && s_hist[j] > s_hist[r_idx] && s_hist[j] >= mag_thr) {
            // Parabolic interpolation of the peak position, wrapped into [0, bins).
            float bin = j + 0.5f * (s_hist[l] - s_hist[r_idx]) / (s_hist[l] - 2.0f * s_hist[j] + s_hist[r_idx]);
            if (bin < 0) bin += SIFT_ORI_HIST_BINS;
            if (bin >= SIFT_ORI_HIST_BINS) bin -= SIFT_ORI_HIST_BINS;

            int pos = atomicAdd(out_num_features, 1);
            if (pos < max_out_features) {
                FeatureDevice new_feat = feat;
                new_feat.ori = (PI2 * bin) / SIFT_ORI_HIST_BINS - M_PI;
                out_features[pos] = new_feat;
            } else {
                // Buffer full: undo the reservation so the final counter equals
                // the number of features actually stored (never > max_out_features).
                atomicSub(out_num_features, 1);
            }
        }
    }
}

// ===== GPU 描述子生成 =====

/*
 * Distributes `mag` over the 2 x 2 x 2 neighbouring cells of a d x d x n
 * histogram (row, column, orientation) using trilinear weights.
 * Row/column neighbours outside [0, d) are dropped; the orientation axis wraps
 * modulo n. Cell (row, col, ori) lives at hist[(row * d + col) * n + ori].
 * Accumulation uses atomicAdd so many threads may target the same hist.
 */
__device__ void d_interp_hist_entry(float* hist, int d, int n,
                                     float rbin, float cbin, float obin, float mag) {
    const int row0 = (int)floorf(rbin);
    const int col0 = (int)floorf(cbin);
    const int ori0 = (int)floorf(obin);
    const float frac_r = rbin - row0;
    const float frac_c = cbin - col0;
    const float frac_o = obin - ori0;

    for (int dr = 0; dr <= 1; dr++) {
        const int row = row0 + dr;
        if (row < 0 || row >= d) continue;
        const float w_row = mag * (dr ? frac_r : 1.0f - frac_r);

        for (int dc = 0; dc <= 1; dc++) {
            const int col = col0 + dc;
            if (col < 0 || col >= d) continue;
            const float w_col = w_row * (dc ? frac_c : 1.0f - frac_c);
            float* cell = hist + ((row * d) + col) * n;

            for (int dori = 0; dori <= 1; dori++) {
                const float w_ori = w_col * (dori ? frac_o : 1.0f - frac_o);
                atomicAdd(&cell[(ori0 + dori) % n], w_ori);
            }
        }
    }
}

/*
 * Descriptor generation: one block per feature (grid.x >= num_features;
 * excess blocks exit), any blockDim.x.
 *
 * Threads stride over the square window around the keypoint, rotate each
 * sample into the keypoint's orientation frame and scatter its Gaussian-
 * weighted gradient magnitude into a d x d x n shared histogram with
 * trilinear interpolation. Thread 0 then normalizes, clamps to
 * SIFT_DESCR_MAG_THR, renormalizes, quantizes (x SIFT_INT_DESCR_FCTR,
 * capped at 255) into features[i].descr and sets features[i].d = d * d * n.
 *
 * The shared histogram holds at most 128 bins. Blocks whose d * d * n exceeds
 * that, or whose d / n are non-positive, leave the feature untouched instead
 * of overflowing shared memory. An all-zero histogram (no usable gradients)
 * yields an all-zero descriptor rather than NaNs.
 */
__global__ void compute_descriptor_kernel(
    const float* const* gauss_pyr,
    const int* widths,
    const int* heights,
    FeatureDevice* features,
    int num_features,
    int d, int n
) {
    const int kMaxHistLen = 128;  // capacity of s_hist (4 x 4 x 8 for standard SIFT)

    int feat_idx = blockIdx.x;
    if (feat_idx >= num_features) return;

    // Both checks are uniform across the block, so returning before
    // __syncthreads() cannot deadlock.
    const int hist_len = d * d * n;
    if (d <= 0 || n <= 0 || hist_len > kMaxHistLen) return;

    FeatureDevice feat = features[feat_idx];
    int w = widths[feat.octv];
    int h = heights[feat.octv];
    int r = feat.r;
    int c = feat.c;
    float ori = feat.ori;
    float scl = feat.scl_octv;

    __shared__ float s_hist[kMaxHistLen];
    for (int i = threadIdx.x; i < hist_len; i += blockDim.x)
        s_hist[i] = 0.0f;
    __syncthreads();

    float cos_t = cosf(ori);
    float sin_t = sinf(ori);
    float bins_per_rad = n / (2.0f * M_PI);
    float exp_denom = d * d * 0.5f;
    float hist_width = SIFT_DESCR_SCL_FCTR * scl;
    int radius = (int)(hist_width * sqrtf(2.0f) * (d + 1.0f) * 0.5f + 0.5f);

    // Assumes 6 Gaussian layers per octave (intvls + 3 with intvls == 3).
    const float* img = gauss_pyr[feat.octv * 6 + feat.intvl];
    int area = (2 * radius + 1) * (2 * radius + 1);

    for (int i = threadIdx.x; i < area; i += blockDim.x) {
        int dy = i / (2 * radius + 1) - radius;
        int dx = i % (2 * radius + 1) - radius;

        // Sample offset rotated into the keypoint frame, in histogram-cell units.
        float c_rot = (dx * cos_t - dy * sin_t) / hist_width;
        float r_rot = (dx * sin_t + dy * cos_t) / hist_width;
        float rbin = r_rot + d / 2.0f - 0.5f;
        float cbin = c_rot + d / 2.0f - 0.5f;

        if (rbin > -1.0f && rbin < d && cbin > -1.0f && cbin < d) {
            int pr = r + dy;
            int pc = c + dx;
            if (pr > 0 && pr < h - 1 && pc > 0 && pc < w - 1) {
                float grad_mag, grad_ori;
                float dx_val = img[pr * w + pc + 1] - img[pr * w + pc - 1];
                float dy_val = img[(pr - 1) * w + pc] - img[(pr + 1) * w + pc];
                grad_mag = sqrtf(dx_val * dx_val + dy_val * dy_val);
                grad_ori = atan2f(dy_val, dx_val);

                // Orientation relative to the keypoint, wrapped into [0, 2*pi).
                grad_ori -= ori;
                while (grad_ori < 0.0f) grad_ori += 2.0f * M_PI;
                while (grad_ori >= 2.0f * M_PI) grad_ori -= 2.0f * M_PI;

                float obin = grad_ori * bins_per_rad;
                float weight = expf(-(c_rot * c_rot + r_rot * r_rot) / exp_denom);
                d_interp_hist_entry(s_hist, d, n, rbin, cbin, obin, grad_mag * weight);
            }
        }
    }
    __syncthreads();

    if (threadIdx.x != 0) return;

    // Normalize to unit length; skip when the histogram is all zeros
    // (1 / sqrtf(0) would turn every bin into NaN).
    float len_sq = 0.0f;
    for (int i = 0; i < hist_len; i++) len_sq += s_hist[i] * s_hist[i];
    if (len_sq > 0.0f) {
        float len_inv = 1.0f / sqrtf(len_sq);
        for (int i = 0; i < hist_len; i++) s_hist[i] *= len_inv;
    }

    // Clamp large bins to reduce the influence of strong gradients.
    for (int i = 0; i < hist_len; i++)
        if (s_hist[i] > SIFT_DESCR_MAG_THR) s_hist[i] = SIFT_DESCR_MAG_THR;

    len_sq = 0.0f;
    for (int i = 0; i < hist_len; i++) len_sq += s_hist[i] * s_hist[i];
    if (len_sq > 0.0f) {
        float len_inv = 1.0f / sqrtf(len_sq);
        for (int i = 0; i < hist_len; i++) s_hist[i] *= len_inv;
    }

    for (int i = 0; i < hist_len; i++) {
        int int_val = (int)(SIFT_INT_DESCR_FCTR * s_hist[i]);
        features[feat_idx].descr[i] = (int_val < 255) ? int_val : 255;
    }
    features[feat_idx].d = hist_len;
}

// ===== Host 接口 (使用 SafeMem) =====

/* Fatal-error helper: on any non-success status print "CUDA error (msg): reason"
 * to stderr and terminate the process with exit(1). */
static void check_cuda(cudaError_t err, const char* msg) {
    if (err == cudaSuccess) {
        return;
    }
    const char* reason = cudaGetErrorString(err);
    fprintf(stderr, "CUDA error (%s): %s\n", msg, reason);
    exit(1);
}

/*
 * Detects and refines DoG extrema over every octave and interior interval.
 *
 * For each octave o and interval i in [1, intvls], detect_extrema_kernel fills
 * a per-pixel candidate mask (w * h bytes, preliminary threshold
 * 0.5 * contr_thr / intvls). refine_features_kernel then consumes that mask and
 * appends interpolated keypoints to a device buffer of at most 50000
 * FeatureDevice entries.
 *
 * On return *d_out_features owns that buffer (the caller frees it with
 * SAFE_CUDA_FREE) and *h_out_count holds the number of valid entries.
 * Returns 0.
 */
int sift_detect_extrema_gpu(
    const float* const* d_dog_pyr,
    const int* d_widths,
    const int* d_heights,
    int octvs, int intvls,
    float contr_thr, float curv_thr,
    FeatureDevice** d_out_features,
    int* h_out_count
) {
    int max_features = 50000;  // raised cap on refined keypoints across all octaves
    FeatureDevice* d_features;
    int* d_count;
    SAFE_CUDA_MALLOC(&d_features, max_features * sizeof(FeatureDevice));
    SAFE_CUDA_MALLOC(&d_count, sizeof(int));
    SAFE_CUDA_MEMSET(d_count, 0, sizeof(int));

    float prelim_contr_thr = 0.5f * contr_thr / intvls;

    if (octvs > 0) {
        // Fetch all octave sizes with two checked copies instead of two
        // unchecked cudaMemcpy calls per octave.
        int* h_widths = (int*)SAFE_MALLOC(octvs * sizeof(int));
        int* h_heights = (int*)SAFE_MALLOC(octvs * sizeof(int));
        SAFE_CUDA_MEMCPY(h_widths, d_widths, octvs * sizeof(int), cudaMemcpyDeviceToHost);
        SAFE_CUDA_MEMCPY(h_heights, d_heights, octvs * sizeof(int), cudaMemcpyDeviceToHost);

        // A single mask sized for the largest octave, reused by every launch
        // (previously cudaMalloc/cudaFree ran once per octave and interval).
        size_t max_pixels = 0;
        for (int o = 0; o < octvs; o++) {
            if (h_widths[o] <= 0 || h_heights[o] <= 0) continue;
            size_t pixels = (size_t)h_widths[o] * (size_t)h_heights[o];
            if (pixels > max_pixels) max_pixels = pixels;
        }

        if (max_pixels > 0) {
            unsigned char* d_mask;
            SAFE_CUDA_MALLOC(&d_mask, max_pixels * sizeof(unsigned char));

            dim3 block(16, 16);
            int threads = 256;
            for (int o = 0; o < octvs; o++) {
                int w = h_widths[o];
                int h = h_heights[o];
                // Empty octave: a zero-sized grid would be an invalid launch configuration.
                if (w <= 0 || h <= 0) continue;

                size_t pixels = (size_t)w * (size_t)h;
                dim3 grid((w + block.x - 1) / block.x, (h + block.y - 1) / block.y);
                int blocks = (int)((pixels + threads - 1) / threads);

                for (int i = 1; i <= intvls; i++) {
                    SAFE_CUDA_MEMSET(d_mask, 0, pixels * sizeof(unsigned char));

                    detect_extrema_kernel<<<grid, block>>>(
                        d_dog_pyr, d_widths, d_heights, o, i, prelim_contr_thr, d_mask
                    );
                    check_cuda(cudaGetLastError(), "detect_extrema_kernel");

                    refine_features_kernel<<<blocks, threads>>>(
                        d_dog_pyr, d_widths, d_heights, o, i, intvls,
                        contr_thr, curv_thr, d_mask, d_features, d_count, max_features
                    );
                    check_cuda(cudaGetLastError(), "refine_features_kernel");
                }
            }
            SAFE_CUDA_FREE(d_mask);
        }
        SAFE_FREE(h_widths);
        SAFE_FREE(h_heights);
    }

    // Blocking copy: also waits for all launches above to finish.
    SAFE_CUDA_MEMCPY(h_out_count, d_count, sizeof(int), cudaMemcpyDeviceToHost);
    *d_out_features = d_features;
    SAFE_CUDA_FREE(d_count);
    return 0;
}

/*
 * Runs calc_ori_kernel (one 128-thread block per input feature) and returns a
 * new device buffer holding one FeatureDevice per dominant orientation.
 *
 * The output buffer is sized for an average of three orientations per
 * keypoint (at least one slot, so it is always a valid allocation the caller
 * frees with SAFE_CUDA_FREE). *h_out_count is clamped to that capacity.
 * With num_features <= 0 no kernel is launched, because a zero-block grid is
 * an invalid launch configuration that check_cuda turns into exit(1).
 * Returns 0.
 */
int sift_calc_oris_gpu(
    const float* const* d_gauss_pyr,
    const int* d_widths,
    const int* d_heights,
    FeatureDevice* d_features,
    int num_features,
    FeatureDevice** d_out_features,
    int* h_out_count
) {
    int max_out = (num_features > 0) ? num_features * 3 : 1;
    FeatureDevice* d_out;
    int* d_out_count;
    SAFE_CUDA_MALLOC(&d_out, max_out * sizeof(FeatureDevice));
    SAFE_CUDA_MALLOC(&d_out_count, sizeof(int));
    SAFE_CUDA_MEMSET(d_out_count, 0, sizeof(int));

    if (num_features > 0) {
        int threads = 128;
        calc_ori_kernel<<<num_features, threads>>>(
            d_gauss_pyr, d_widths, d_heights,
            d_features, num_features, d_out_count, d_out, max_out
        );
        check_cuda(cudaGetLastError(), "calc_ori_kernel");
    }

    int count = 0;
    SAFE_CUDA_MEMCPY(&count, d_out_count, sizeof(int), cudaMemcpyDeviceToHost);
    // A keypoint can have more peaks than there are free slots; never report
    // more features than the buffer actually holds.
    if (count > max_out) count = max_out;
    if (count < 0) count = 0;
    *h_out_count = count;

    SAFE_CUDA_FREE(d_out_count);
    *d_out_features = d_out;
    return 0;
}

/*
 * Fills the descriptors of d_features in place: one 128-thread block per
 * feature running compute_descriptor_kernel with a d x d x n histogram.
 * With num_features <= 0 nothing is launched, because a zero-block grid is an
 * invalid launch configuration that check_cuda turns into exit(1).
 * Returns 0.
 */
int sift_compute_descriptors_gpu(
    const float* const* d_gauss_pyr,
    const int* d_widths,
    const int* d_heights,
    FeatureDevice* d_features,
    int num_features,
    int d, int n
) {
    if (num_features <= 0) return 0;

    int threads = 128;
    compute_descriptor_kernel<<<num_features, threads>>>(
        d_gauss_pyr, d_widths, d_heights, d_features, num_features, d, n
    );
    check_cuda(cudaGetLastError(), "compute_descriptor_kernel");
    return 0;
}

2.3 sift_algorithm.h

#ifndef SIFT_ALGORITHM_H
#define SIFT_ALGORITHM_H

#ifdef __cplusplus
extern "C" {
#endif

#include "gray_image.h"
#include "sift_types.h"

// Growable array of Feature pointers. The array owns the features:
// feature_array_free() calls feature_free() on every stored entry.
typedef struct {
    Feature** data;   // heap array of Feature*, `capacity` slots, first `size` in use
    int size;         // number of stored features
    int capacity;     // number of allocated slots in data
} FeatureArray;

// Allocates an empty FeatureArray; release it with feature_array_free().
FeatureArray* feature_array_new(void);
// Frees every stored Feature, the slot array and the FeatureArray itself; NULL is ignored.
void feature_array_free(FeatureArray* arr);
// Appends feat, doubling the slot array (starting at 16) when it is full.
void feature_array_push(FeatureArray* arr, Feature* feat);
// Sorts the stored features by descending scale (Feature::scl).
void feature_array_sort_by_scale(FeatureArray* arr);

// Runs the GPU SIFT pipeline on img and returns a new FeatureArray sorted by descending scale.
FeatureArray* sift_extract_features(const GrayImage* img);

#ifdef __cplusplus
}
#endif

#endif

2.4 sift_algorithm.cu

#include "sift_algorithm.h"
#include "sift_detect.h"
#include "safemem/safemem.h"
#include "image_ops.h"
#include <cuda_runtime.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

// SIFT parameters. Within the code shown for this file only SIFT_INIT_SIGMA,
// SIFT_INTVLS, SIFT_SIGMA, SIFT_CONTR_THR, SIFT_CURV_THR, SIFT_IMG_DBL,
// SIFT_DESCR_WIDTH and SIFT_DESCR_HIST_BINS are referenced; the others
// presumably mirror the definitions used by the detection kernels — verify
// they stay in sync.
#define SIFT_INIT_SIGMA 0.5              // blur assumed already present in the input (create_init_img)
#define SIFT_IMG_BORDER 5
#define SIFT_MAX_INTERP_STEPS 5
#define SIFT_ORI_HIST_BINS 36
#define SIFT_ORI_SIG_FCTR 1.5
#define SIFT_ORI_RADIUS (3.0 * SIFT_ORI_SIG_FCTR)
#define SIFT_ORI_SMOOTH_PASSES 2
#define SIFT_ORI_PEAK_RATIO 0.8
#define SIFT_DESCR_SCL_FCTR 3.0
#define SIFT_DESCR_MAG_THR 0.2
#define SIFT_INT_DESCR_FCTR 512.0

#define SIFT_INTVLS 3                    // intervals per octave (Gaussian layers = SIFT_INTVLS + 3)
#define SIFT_SIGMA 1.6                   // base sigma of each octave
#define SIFT_CONTR_THR 0.04              // contrast threshold passed to extrema refinement
#define SIFT_CURV_THR 10                 // edge (curvature ratio) threshold passed to refinement
#define SIFT_IMG_DBL 1                   // non-zero: upsample input 2x before building pyramids
#define SIFT_DESCR_WIDTH 4               // descriptor grid width d (d x d cells)
#define SIFT_DESCR_HIST_BINS 8           // orientation bins n per cell (descriptor length d*d*n)

/* Allocates an empty FeatureArray (no storage yet).
 * Returns NULL if the allocation fails instead of dereferencing NULL. */
FeatureArray* feature_array_new(void) {
    FeatureArray* arr = (FeatureArray*)malloc(sizeof(FeatureArray));
    if (!arr) return NULL;
    arr->data = NULL;
    arr->size = 0;
    arr->capacity = 0;
    return arr;
}

/* Releases every stored Feature, the slot array and the array itself.
 * Passing NULL is a no-op. */
void feature_array_free(FeatureArray* arr) {
    if (arr == NULL) {
        return;
    }
    Feature** cursor = arr->data;
    Feature** const end = arr->data + arr->size;
    while (cursor != end) {
        feature_free(*cursor);
        ++cursor;
    }
    free(arr->data);
    free(arr);
}

/*
 * Appends feat to arr; the array takes ownership (feature_array_free frees it).
 * Capacity starts at 16 and doubles when full.
 *
 * If growing the slot array fails, the error is reported on stderr and arr is
 * left unchanged (old storage and capacity intact). In that case feat is NOT
 * stored and remains owned by the caller. A NULL arr is ignored.
 */
void feature_array_push(FeatureArray* arr, Feature* feat) {
    if (!arr) return;
    if (arr->size >= arr->capacity) {
        int new_capacity = arr->capacity == 0 ? 16 : arr->capacity * 2;
        // realloc into a temporary so a failure does not leak or lose the old buffer.
        Feature** grown = (Feature**)realloc(arr->data, (size_t)new_capacity * sizeof(Feature*));
        if (!grown) {
            fprintf(stderr, "feature_array_push: out of memory (capacity %d)\n", new_capacity);
            return;
        }
        arr->data = grown;
        arr->capacity = new_capacity;
    }
    arr->data[arr->size++] = feat;
}

/* qsort comparator over Feature* elements: larger scl sorts first
 * (descending order); equal or unordered scales compare as 0. */
static int compare_feature_scale(const void* a, const void* b) {
    const Feature* lhs = *(Feature* const*)a;
    const Feature* rhs = *(Feature* const*)b;
    if (lhs->scl > rhs->scl) {
        return -1;
    }
    if (lhs->scl < rhs->scl) {
        return 1;
    }
    return 0;
}

/* Sorts arr's features in place by descending scale. NULL or arrays with
 * fewer than two entries are left untouched (consistent with feature_array_free). */
void feature_array_sort_by_scale(FeatureArray* arr) {
    if (!arr || arr->size < 2) return;
    qsort(arr->data, arr->size, sizeof(Feature*), compare_feature_scale);
}

// ===== GPU 内存管理辅助函数 =====

/* Fatal-error helper: on any non-success status print "CUDA error (msg): reason"
 * to stderr and terminate the process with exit(1). */
static void check_cuda(cudaError_t err, const char* msg) {
    if (err == cudaSuccess) {
        return;
    }
    const char* reason = cudaGetErrorString(err);
    fprintf(stderr, "CUDA error (%s): %s\n", msg, reason);
    exit(1);
}

// Copies img's width * height float pixels into a new device buffer and
// returns it (caller releases it with SAFE_CUDA_FREE).
static float* upload_image_to_device(const GrayImage* img) {
    const size_t bytes = img->width * img->height * sizeof(float);
    float* d_pixels;
    SAFE_CUDA_MALLOC(&d_pixels, bytes);
    SAFE_CUDA_MEMCPY(d_pixels, img->data, bytes, cudaMemcpyHostToDevice);
    return d_pixels;
}

// Gaussian-blurs a w x h device image with a separable (horizontal then
// vertical) kernel and returns a newly allocated device image.
// For sigma <= 0.01 the input is simply copied. The kernel has
// ceil(6 * sigma) taps, bumped to an odd count, normalized to sum to 1.
static float* gaussian_blur_device(const float* d_src, int w, int h, double sigma) {
    const size_t img_bytes = w * h * sizeof(float);

    if (sigma <= 0.01) {
        float* d_copy;
        SAFE_CUDA_MALLOC(&d_copy, img_bytes);
        SAFE_CUDA_MEMCPY(d_copy, d_src, img_bytes, cudaMemcpyDeviceToDevice);
        return d_copy;
    }

    int taps = (int)ceil(sigma * 6.0);
    if ((taps & 1) == 0) taps += 1;
    const int kradius = taps / 2;

    float* h_weights = (float*)SAFE_MALLOC(taps * sizeof(float));
    double total = 0.0;
    for (int t = 0; t < taps; t++) {
        const double off = t - kradius;
        h_weights[t] = (float)exp(-(off * off) / (2.0 * sigma * sigma));
        total += h_weights[t];
    }
    for (int t = 0; t < taps; t++) {
        h_weights[t] /= (float)total;
    }

    float* d_tmp;
    float* d_dst;
    float* d_weights;
    SAFE_CUDA_MALLOC(&d_tmp, img_bytes);
    SAFE_CUDA_MALLOC(&d_dst, img_bytes);
    SAFE_CUDA_MALLOC(&d_weights, taps * sizeof(float));
    SAFE_CUDA_MEMCPY(d_weights, h_weights, taps * sizeof(float), cudaMemcpyHostToDevice);

    const dim3 threads(16, 16);
    const dim3 blocks((w + threads.x - 1) / threads.x, (h + threads.y - 1) / threads.y);

    // Kernels reused from image_ops.cu.
    extern __global__ void gaussian_blur_h_kernel(const float* src, float* dst, int width, int height,
                                                   const float* kernel, int kernel_size, int radius);
    extern __global__ void gaussian_blur_v_kernel(const float* src, float* dst, int width, int height,
                                                   const float* kernel, int kernel_size, int radius);

    gaussian_blur_h_kernel<<<blocks, threads>>>(d_src, d_tmp, w, h, d_weights, taps, kradius);
    check_cuda(cudaGetLastError(), "gaussian_blur_h");
    gaussian_blur_v_kernel<<<blocks, threads>>>(d_tmp, d_dst, w, h, d_weights, taps, kradius);
    check_cuda(cudaGetLastError(), "gaussian_blur_v");

    SAFE_CUDA_FREE(d_tmp);
    SAFE_CUDA_FREE(d_weights);
    SAFE_FREE(h_weights);
    return d_dst;
}

// Halves a srcW x srcH device image (integer division of each dimension) with
// downsample_kernel from image_ops.cu; returns a new device buffer.
static float* downsample_device(const float* d_src, int srcW, int srcH) {
    const int half_w = srcW / 2;
    const int half_h = srcH / 2;
    const dim3 threads(16, 16);
    const dim3 blocks((half_w + threads.x - 1) / threads.x, (half_h + threads.y - 1) / threads.y);

    float* d_small;
    SAFE_CUDA_MALLOC(&d_small, half_w * half_h * sizeof(float));

    extern __global__ void downsample_kernel(const float* src, float* dst, int srcW, int srcH, int dstW, int dstH);
    downsample_kernel<<<blocks, threads>>>(d_src, d_small, srcW, srcH, half_w, half_h);
    check_cuda(cudaGetLastError(), "downsample");
    return d_small;
}

// Difference of two w x h device images (subtract_kernel from image_ops.cu,
// called with d_a first), written to a new device buffer. Used to build DoG layers.
static float* subtract_device(const float* d_a, const float* d_b, int w, int h) {
    const dim3 threads(16, 16);
    const dim3 blocks((w + threads.x - 1) / threads.x, (h + threads.y - 1) / threads.y);

    float* d_diff;
    SAFE_CUDA_MALLOC(&d_diff, w * h * sizeof(float));

    extern __global__ void subtract_kernel(const float* a, const float* b, float* dst, int width, int height);
    subtract_kernel<<<blocks, threads>>>(d_a, d_b, d_diff, w, h);
    check_cuda(cudaGetLastError(), "subtract");
    return d_diff;
}

// ===== 构建 Device 金字塔 =====

/*
 * Host-side descriptor of an image pyramid whose images live in device memory.
 * Image (octave o, layer i) is data[o * layers + i]. The data/widths/heights
 * arrays themselves are host allocations (SAFE_MALLOC in build_*_pyr_device);
 * release everything with free_device_pyr().
 */
typedef struct {
    float** data;   // host array of device pointers, octvs * layers entries
    int* widths;    // per-octave image width
    int* heights;   // per-octave image height
    int octvs;      // number of octaves
    int layers;     // images per octave (intvls + 3 for Gaussian, intvls + 2 for DoG)
} DevicePyr;

/*
 * Builds a Gaussian scale-space pyramid on the device.
 *
 * octvs octaves of intvls + 3 layers. Octave o has size (baseW >> o, baseH >> o).
 * Layer 0 of octave 0 is a copy of d_base; layer 0 of every later octave is the
 * downsampled layer `intvls` of the previous octave. Every other layer blurs the
 * previous layer incrementally with sig[i], so that layer i has total blur
 * sigma * k^i with k = 2^(1/intvls).
 *
 * d_base is not taken over (the caller still frees it). The result is freed
 * with free_device_pyr().
 */
static DevicePyr* build_gauss_pyr_device(float* d_base, int baseW, int baseH,
                                          int octvs, int intvls, double sigma) {
    DevicePyr* pyr = (DevicePyr*)SAFE_MALLOC(sizeof(DevicePyr));
    pyr->octvs = octvs;
    pyr->layers = intvls + 3;
    pyr->data = (float**)SAFE_MALLOC(octvs * pyr->layers * sizeof(float*));
    pyr->widths = (int*)SAFE_MALLOC(octvs * sizeof(int));
    pyr->heights = (int*)SAFE_MALLOC(octvs * sizeof(int));

    // Incremental blur from layer i-1 to layer i.
    double* sig = (double*)SAFE_MALLOC((intvls + 3) * sizeof(double));
    double k = pow(2.0, 1.0 / intvls);
    sig[0] = sigma;
    for (int i = 1; i < intvls + 3; i++) {
        double sig_prev = pow(k, i - 1) * sigma;
        double sig_total = sig_prev * k;
        sig[i] = sqrt(sig_total * sig_total - sig_prev * sig_prev);
    }

    for (int o = 0; o < octvs; o++) {
        int w = baseW >> o;
        int h = baseH >> o;
        pyr->widths[o] = w;
        pyr->heights[o] = h;

        for (int i = 0; i < intvls + 3; i++) {
            int idx = o * pyr->layers + i;
            if (o == 0 && i == 0) {
                // Clone the base image so the pyramid owns all of its layers.
                float* d_clone;
                SAFE_CUDA_MALLOC(&d_clone, w * h * sizeof(float));
                SAFE_CUDA_MEMCPY(d_clone, d_base, w * h * sizeof(float), cudaMemcpyDeviceToDevice);
                pyr->data[idx] = d_clone;
            } else if (i == 0) {
                // Pass the previous octave's real dimensions. (w * 2, h * 2) is one
                // pixel short whenever that octave's width/height is odd, which
                // mismatches the source row stride. The destination size is
                // unchanged: (baseW >> (o-1)) / 2 == baseW >> o.
                pyr->data[idx] = downsample_device(pyr->data[(o - 1) * pyr->layers + intvls],
                                                   pyr->widths[o - 1], pyr->heights[o - 1]);
            } else {
                pyr->data[idx] = gaussian_blur_device(pyr->data[idx - 1], w, h, sig[i]);
            }
        }
    }

    SAFE_FREE(sig);
    return pyr;
}

/* Builds the difference-of-Gaussian pyramid: for every octave, layer i is
 * gauss[i + 1] - gauss[i] (intvls + 2 layers per octave). Octave sizes are
 * copied from gauss_pyr. The result is freed with free_device_pyr(). */
static DevicePyr* build_dog_pyr_device(DevicePyr* gauss_pyr, int octvs, int intvls) {
    const int dog_layers = intvls + 2;
    const size_t dims_bytes = octvs * sizeof(int);

    DevicePyr* dog = (DevicePyr*)SAFE_MALLOC(sizeof(DevicePyr));
    dog->octvs = octvs;
    dog->layers = dog_layers;
    dog->data = (float**)SAFE_MALLOC(octvs * dog_layers * sizeof(float*));
    dog->widths = (int*)SAFE_MALLOC(dims_bytes);
    dog->heights = (int*)SAFE_MALLOC(dims_bytes);
    memcpy(dog->widths, gauss_pyr->widths, dims_bytes);
    memcpy(dog->heights, gauss_pyr->heights, dims_bytes);

    for (int o = 0; o < octvs; o++) {
        float** g_octave = gauss_pyr->data + o * gauss_pyr->layers;
        float** d_octave = dog->data + o * dog_layers;
        const int ow = dog->widths[o];
        const int oh = dog->heights[o];
        for (int i = 0; i < dog_layers; i++) {
            d_octave[i] = subtract_device(g_octave[i + 1], g_octave[i], ow, oh);
        }
    }
    return dog;
}

/* Releases every device image of pyr, its host metadata arrays and pyr itself.
 * NULL is ignored (consistent with feature_array_free). */
static void free_device_pyr(DevicePyr* pyr) {
    if (!pyr) return;
    for (int i = 0; i < pyr->octvs * pyr->layers; i++) {
        SAFE_CUDA_FREE(pyr->data[i]);
    }
    SAFE_FREE(pyr->data);
    SAFE_FREE(pyr->widths);
    SAFE_FREE(pyr->heights);
    SAFE_FREE(pyr);
}

// ===== 辅助: 创建初始图像 =====

/* Produces the pyramid base image. The input is assumed to already carry
 * SIFT_INIT_SIGMA of blur (doubled when upsampled), so only the remaining
 * blur needed to reach `sigma` is applied. With img_dbl != 0 the image is
 * first upscaled 2x with cubic interpolation. Returns a new image; img is
 * not modified. */
static GrayImage* create_init_img(const GrayImage* img, int img_dbl, double sigma) {
    GrayImage* work = gray_image_clone(img);

    if (img_dbl == 0) {
        const float sig_diff = (float)sqrt(sigma * sigma - SIFT_INIT_SIGMA * SIFT_INIT_SIGMA);
        GrayImage* out = image_ops_gaussian_blur(work, sig_diff);
        gray_image_free(work);
        return out;
    }

    const float sig_diff = (float)sqrt(sigma * sigma - SIFT_INIT_SIGMA * SIFT_INIT_SIGMA * 4.0);
    GrayImage* doubled = image_ops_resize_cubic(work, img->width * 2, img->height * 2);
    gray_image_free(work);
    GrayImage* out = image_ops_gaussian_blur(doubled, sig_diff);
    gray_image_free(doubled);
    return out;
}

// ===== 主特征提取函数 (GPU 优化版) =====

/*
 * GPU SIFT feature extraction.
 *
 * Pipeline:
 *   1. build the (optionally 2x-upsampled, pre-blurred) base image on the host;
 *   2. upload it and build Gaussian and DoG pyramids on the device;
 *   3. detect and refine extrema;
 *   4. assign scales on the host;
 *   5. assign orientations and descriptors on the device;
 *   6. download the results into a host FeatureArray, undo the 2x upsampling
 *      in coordinates/scale, and sort by descending scale.
 *
 * Returns a new FeatureArray (free with feature_array_free). The array is
 * empty when img is NULL or nothing is detected. It is NULL only if
 * feature_array_new() fails.
 */
FeatureArray* sift_extract_features(const GrayImage* img) {
    if (!img) return feature_array_new();

    // 1. Initial image
    GrayImage* init_img = create_init_img(img, SIFT_IMG_DBL, SIFT_SIGMA);
    if (!init_img) return feature_array_new();
    int octvs = (int)(log(fmin(init_img->width, init_img->height)) / log(2.0) - 2);
    if (octvs < 1) octvs = 1;

    // 2. Upload the initial image to the device
    float* d_init = upload_image_to_device(init_img);

    // 3. Gaussian pyramid on the device
    DevicePyr* gauss_pyr = build_gauss_pyr_device(d_init, init_img->width, init_img->height,
                                                   octvs, SIFT_INTVLS, SIFT_SIGMA);
    SAFE_CUDA_FREE(d_init);

    // 4. DoG pyramid on the device
    DevicePyr* dog_pyr = build_dog_pyr_device(gauss_pyr, octvs, SIFT_INTVLS);

    // 5. Upload pyramid metadata (device-pointer tables and octave sizes)
    float** d_dog_ptrs;
    int* d_widths;
    int* d_heights;
    SAFE_CUDA_MALLOC(&d_dog_ptrs, octvs * dog_pyr->layers * sizeof(float*));
    SAFE_CUDA_MEMCPY(d_dog_ptrs, dog_pyr->data, octvs * dog_pyr->layers * sizeof(float*), cudaMemcpyHostToDevice);
    SAFE_CUDA_MALLOC(&d_widths, octvs * sizeof(int));
    SAFE_CUDA_MEMCPY(d_widths, dog_pyr->widths, octvs * sizeof(int), cudaMemcpyHostToDevice);
    SAFE_CUDA_MALLOC(&d_heights, octvs * sizeof(int));
    SAFE_CUDA_MEMCPY(d_heights, dog_pyr->heights, octvs * sizeof(int), cudaMemcpyHostToDevice);

    float** d_gauss_ptrs;
    SAFE_CUDA_MALLOC(&d_gauss_ptrs, octvs * gauss_pyr->layers * sizeof(float*));
    SAFE_CUDA_MEMCPY(d_gauss_ptrs, gauss_pyr->data, octvs * gauss_pyr->layers * sizeof(float*), cudaMemcpyHostToDevice);

    // 6. GPU extrema detection
    FeatureDevice* d_features = NULL;
    int num_features = 0;
    sift_detect_extrema_gpu(d_dog_ptrs, d_widths, d_heights, octvs, SIFT_INTVLS,
                            SIFT_CONTR_THR, SIFT_CURV_THR, &d_features, &num_features);

    // 7. Scales: one bulk copy each way instead of two blocking copies per feature.
    if (num_features > 0) {
        size_t feat_bytes = (size_t)num_features * sizeof(FeatureDevice);
        FeatureDevice* h_tmp = (FeatureDevice*)SAFE_MALLOC(feat_bytes);
        SAFE_CUDA_MEMCPY(h_tmp, d_features, feat_bytes, cudaMemcpyDeviceToHost);
        for (int i = 0; i < num_features; i++) {
            double intvl = h_tmp[i].intvl + h_tmp[i].subintvl;
            h_tmp[i].scl = SIFT_SIGMA * pow(2.0, h_tmp[i].octv + intvl / SIFT_INTVLS);
            h_tmp[i].scl_octv = SIFT_SIGMA * pow(2.0, intvl / SIFT_INTVLS);
        }
        SAFE_CUDA_MEMCPY(d_features, h_tmp, feat_bytes, cudaMemcpyHostToDevice);
        SAFE_FREE(h_tmp);
    }

    // 8. GPU orientation assignment (skipped when nothing was detected:
    //    a zero-block launch is an invalid configuration).
    FeatureDevice* d_features_with_ori = NULL;
    int num_features_with_ori = 0;
    if (num_features > 0) {
        sift_calc_oris_gpu(d_gauss_ptrs, d_widths, d_heights, d_features, num_features,
                           &d_features_with_ori, &num_features_with_ori);
    }
    if (d_features) {
        SAFE_CUDA_FREE(d_features);
    }

    // 9. GPU descriptor generation
    if (num_features_with_ori > 0) {
        sift_compute_descriptors_gpu(d_gauss_ptrs, d_widths, d_heights,
                                      d_features_with_ori, num_features_with_ori,
                                      SIFT_DESCR_WIDTH, SIFT_DESCR_HIST_BINS);
    }

    // 10. Download features to the host
    FeatureArray* features = feature_array_new();
    if (features && num_features_with_ori > 0) {
        size_t out_bytes = (size_t)num_features_with_ori * sizeof(FeatureDevice);
        FeatureDevice* h_features = (FeatureDevice*)SAFE_MALLOC(out_bytes);
        SAFE_CUDA_MEMCPY(h_features, d_features_with_ori, out_bytes, cudaMemcpyDeviceToHost);

        for (int i = 0; i < num_features_with_ori; i++) {
            const FeatureDevice* src = &h_features[i];
            Feature* feat = feature_new();
            if (!feat) break;
            feat->x = src->x;
            feat->y = src->y;
            feat->scl = src->scl;
            feat->ori = src->ori;
            feat->d = src->d;
            // Element-wise copy converts from whatever element type FeatureDevice::descr
            // uses (assumed >= FEATURE_MAX_D entries; the descriptor kernel writes
            // d*d*n == 128 of them) instead of assuming it is double.
            for (int j = 0; j < FEATURE_MAX_D; j++) feat->descr[j] = src->descr[j];
            feat->img_pt_x = src->img_pt_x;
            feat->img_pt_y = src->img_pt_y;
            feat->feature_data.r = src->r;
            feat->feature_data.c = src->c;
            feat->feature_data.octv = src->octv;
            feat->feature_data.intvl = src->intvl;
            feat->feature_data.subintvl = src->subintvl;
            feat->feature_data.scl_octv = src->scl_octv;
            feature_array_push(features, feat);
        }
        SAFE_FREE(h_features);
    }

    // 11. Undo the 2x upsampling of the base image
    if (features && SIFT_IMG_DBL != 0) {
        for (int i = 0; i < features->size; i++) {
            Feature* feat = features->data[i];
            feat->x /= 2.0;
            feat->y /= 2.0;
            feat->scl /= 2.0;
            feat->img_pt_x = (float)(feat->img_pt_x / 2.0f);
            feat->img_pt_y = (float)(feat->img_pt_y / 2.0f);
        }
    }

    if (features) feature_array_sort_by_scale(features);

    // Cleanup
    if (d_features_with_ori) {
        SAFE_CUDA_FREE(d_features_with_ori);
    }
    SAFE_CUDA_FREE(d_dog_ptrs);
    SAFE_CUDA_FREE(d_gauss_ptrs);
    SAFE_CUDA_FREE(d_widths);
    SAFE_CUDA_FREE(d_heights);
    free_device_pyr(gauss_pyr);
    free_device_pyr(dog_pyr);
    gray_image_free(init_img);

    return features;
}

2.5 sift_matcher.h

#ifndef SIFT_MATCHER_H
#define SIFT_MATCHER_H

#ifdef __cplusplus
extern "C" {
#endif

#include "sift_types.h"
#include "sift_algorithm.h"

// Similarity score in [0, 1]: 1 - exp(-3 * matches / min(size1, size2)); 0.0 for NULL or empty inputs.
double sift_matcher_compute_similarity(const FeatureArray* features1, const FeatureArray* features2);
// Number of features in features1 whose nearest neighbour in features2 passes the
// 0.75 distance-ratio test (computed on the GPU); 0 for NULL or empty inputs.
int sift_matcher_count_matches(const FeatureArray* features1, const FeatureArray* features2);
// Euclidean distance between the descriptors of a and b (host implementation).
double sift_matcher_descriptor_distance(const Feature* a, const Feature* b);

#ifdef __cplusplus
}
#endif

#endif

2.6 sift_matcher.cu

#include "sift_matcher.h"
#include "safemem/safemem.h"
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>
#include <stdlib.h>
#include <stdio.h>

// ===== Optimization 1: float precision + shared-memory tiles =====

#define TILE_SIZE 32   // tile edge; the distance kernel is launched with TILE_SIZE x TILE_SIZE threads
#define DESCR_DIM 128  // descriptor length; the host flattens FEATURE_MAX_D values per feature, so these must match

/*
 * All-pairs Euclidean distances between two descriptor sets.
 *
 * Layouts:
 *   descr1    row-major n1 x DESCR_DIM: descriptor i, dim k at descr1[i * DESCR_DIM + k]
 *   descr2_t  transposed DESCR_DIM x n2: descriptor j, dim k at descr2_t[k * n2 + j]
 *   distances row-major n1 x n2:         distances[i * n2 + j]
 *
 * Launch: blockDim = (TILE_SIZE, TILE_SIZE), grid = (ceil(n1 / TILE_SIZE),
 * ceil(n2 / TILE_SIZE)). threadIdx.y selects the descr1 row and threadIdx.x the
 * descr2 column. Static shared memory: 2 * TILE_SIZE^2 floats.
 *
 * Each stage loads TILE_SIZE dimensions. For both inputs consecutive
 * threadIdx.x read consecutive global addresses (coalesced). Shared reads are
 * either warp-wide broadcasts (s_descr1) or consecutive words (s_descr2), so
 * they are bank-conflict free.
 *
 * __launch_bounds__ must admit the TILE_SIZE^2 (= 1024) threads the host
 * launches; the previous (256, 2) bound made that launch fail.
 */
__global__ void __launch_bounds__(TILE_SIZE * TILE_SIZE)
descriptor_distances_optimized_kernel(
    const float* __restrict__ descr1,
    const float* __restrict__ descr2_t,  // transposed storage: DESCR_DIM x n2
    float* __restrict__ distances,
    int n1, int n2
) {
    __shared__ float s_descr1[TILE_SIZE][TILE_SIZE];  // [descr1 row in tile][dim in stage]
    __shared__ float s_descr2[TILE_SIZE][TILE_SIZE];  // [dim in stage][descr2 column in tile]

    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int idx1 = blockIdx.x * TILE_SIZE + ty;
    const int idx2 = blockIdx.y * TILE_SIZE + tx;

    float dist = 0.0f;

    for (int tile = 0; tile < DESCR_DIM; tile += TILE_SIZE) {
        // descr1: thread (tx, ty) loads dim tile + tx of row idx1.
        const int k1 = tile + tx;
        s_descr1[ty][tx] = (idx1 < n1 && k1 < DESCR_DIM)
            ? descr1[(size_t)idx1 * DESCR_DIM + k1]
            : 0.0f;

        // descr2_t: thread (tx, ty) loads dim tile + ty of column idx2, using
        // the transposed layout the host builds (k * n2 + j).
        const int k2 = tile + ty;
        s_descr2[ty][tx] = (idx2 < n2 && k2 < DESCR_DIM)
            ? descr2_t[(size_t)k2 * n2 + idx2]
            : 0.0f;

        __syncthreads();

        // Zero padding on both sides contributes nothing past DESCR_DIM.
        #pragma unroll
        for (int k = 0; k < TILE_SIZE; k++) {
            const float diff = s_descr1[ty][k] - s_descr2[k][tx];
            dist += diff * diff;
        }
        __syncthreads();
    }

    if (idx1 < n1 && idx2 < n2) {
        distances[(size_t)idx1 * n2 + idx2] = sqrtf(dist);
    }
}

// ===== 优化 2: GPU 端 ratio test + 匹配计数 =====

/*
 * Lowe ratio test over a row-major n1 x n2 distance matrix: one thread per row
 * (1-D launch covering n1 threads). A row counts as a match when
 * best / second_best < ratio_thresh, and each match adds 1 to *match_count.
 *
 * Note: with n2 == 1 the second-best distance stays FLT_MAX, so every row
 * passes the test. Rows whose two smallest distances are both 0 never match.
 */
__global__ void match_count_kernel(
    const float* __restrict__ distances,
    int n1, int n2,
    float ratio_thresh,
    int* __restrict__ match_count
) {
    int idx1 = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx1 >= n1) return;

    // 64-bit row offset: idx1 * n2 in int overflows once n1 * n2 > INT_MAX.
    const float* row = distances + (size_t)idx1 * (size_t)n2;

    float best_dist = FLT_MAX;
    float second_best_dist = FLT_MAX;

    for (int j = 0; j < n2; j++) {
        float dist = row[j];
        if (dist < best_dist) {
            second_best_dist = best_dist;
            best_dist = dist;
        } else if (dist < second_best_dist) {
            second_best_dist = dist;
        }
    }

    if (second_best_dist > 0.0f && best_dist / second_best_dist < ratio_thresh) {
        atomicAdd(match_count, 1);
    }
}

// ===== Host fallback (保留用于对比测试) =====

/*
 * Euclidean distance between the descriptors of a and b (host reference path).
 *
 * Compares the first min(a->d, b->d) entries. A length of 0 (descriptor
 * length not recorded) falls back to FEATURE_MAX_D. Any length above
 * FEATURE_MAX_D is clamped so descr[] is never read out of bounds. A negative
 * length yields 0.0.
 */
double sift_matcher_descriptor_distance(const Feature* a, const Feature* b) {
    int len = (a->d < b->d) ? a->d : b->d;
    if (len == 0 || len > FEATURE_MAX_D) len = FEATURE_MAX_D;

    double sum = 0.0;
    for (int i = 0; i < len; i++) {
        const double diff = a->descr[i] - b->descr[i];
        sum += diff * diff;
    }
    return sqrt(sum);
}

/* Fatal-error helper: on any non-success status print "CUDA error (msg): reason"
 * to stderr and terminate the process with exit(1). */
static void check_cuda(cudaError_t err, const char* msg) {
    if (err == cudaSuccess) {
        return;
    }
    const char* reason = cudaGetErrorString(err);
    fprintf(stderr, "CUDA error (%s): %s\n", msg, reason);
    exit(1);
}

// ===== 优化后的匹配函数 =====

/*
 * Counts features in features1 whose nearest neighbour in features2 passes
 * the 0.75 distance-ratio test, entirely on the GPU.
 *
 * Descriptors are converted double -> float. descr1 is uploaded row-major
 * (n1 x d); descr2 is uploaded transposed (d x n2, element (i, j) at
 * j * n2 + i) so the distance kernel reads it coalesced. The full n1 x n2
 * distance matrix is materialized on the device, so memory grows as
 * 4 * n1 * n2 bytes. Returns 0 for NULL or empty inputs.
 */
int sift_matcher_count_matches(const FeatureArray* features1, const FeatureArray* features2) {
    if (!features1 || !features2 || features1->size == 0 || features2->size == 0)
        return 0;

    int n1 = features1->size;
    int n2 = features2->size;
    int d = FEATURE_MAX_D;  // must equal DESCR_DIM used by the distance kernel

    // Byte counts in size_t: n1 * n2 in int overflows long before device memory runs out.
    const size_t descr1_bytes = (size_t)n1 * (size_t)d * sizeof(float);
    const size_t descr2_bytes = (size_t)n2 * (size_t)d * sizeof(float);
    const size_t dist_bytes = (size_t)n1 * (size_t)n2 * sizeof(float);

    // descr1: row-major n1 x d
    float* h_descr1 = (float*)SAFE_MALLOC(descr1_bytes);
    for (int i = 0; i < n1; i++) {
        const double* src = features1->data[i]->descr;
        for (int j = 0; j < d; j++)
            h_descr1[(size_t)i * d + j] = (float)src[j];
    }

    // descr2: written straight into the transposed d x n2 layout (no intermediate copy)
    float* h_descr2_t = (float*)SAFE_MALLOC(descr2_bytes);
    for (int i = 0; i < n2; i++) {
        const double* src = features2->data[i]->descr;
        for (int j = 0; j < d; j++)
            h_descr2_t[(size_t)j * n2 + i] = (float)src[j];
    }

    float *d_descr1, *d_descr2_t, *d_distances;
    SAFE_CUDA_MALLOC(&d_descr1, descr1_bytes);
    SAFE_CUDA_MALLOC(&d_descr2_t, descr2_bytes);
    SAFE_CUDA_MALLOC(&d_distances, dist_bytes);

    SAFE_CUDA_MEMCPY(d_descr1, h_descr1, descr1_bytes, cudaMemcpyHostToDevice);
    SAFE_CUDA_MEMCPY(d_descr2_t, h_descr2_t, descr2_bytes, cudaMemcpyHostToDevice);

    // Distance kernel: TILE_SIZE x TILE_SIZE threads per block, one block per tile pair.
    dim3 block(TILE_SIZE, TILE_SIZE);
    dim3 grid((n1 + TILE_SIZE - 1) / TILE_SIZE, (n2 + TILE_SIZE - 1) / TILE_SIZE);
    descriptor_distances_optimized_kernel<<<grid, block>>>(d_descr1, d_descr2_t, d_distances, n1, n2);
    check_cuda(cudaGetLastError(), "descriptor_distances_optimized_kernel");

    // Ratio test on the GPU
    int* d_match_count;
    SAFE_CUDA_MALLOC(&d_match_count, sizeof(int));
    SAFE_CUDA_MEMSET(d_match_count, 0, sizeof(int));

    int threads = 256;
    int blocks = (n1 + threads - 1) / threads;
    match_count_kernel<<<blocks, threads>>>(d_distances, n1, n2, 0.75f, d_match_count);
    check_cuda(cudaGetLastError(), "match_count_kernel");

    // Blocking copy: waits for both kernels.
    int h_match_count = 0;
    SAFE_CUDA_MEMCPY(&h_match_count, d_match_count, sizeof(int), cudaMemcpyDeviceToHost);

    SAFE_FREE(h_descr1);
    SAFE_FREE(h_descr2_t);
    SAFE_CUDA_FREE(d_descr1);
    SAFE_CUDA_FREE(d_descr2_t);
    SAFE_CUDA_FREE(d_distances);
    SAFE_CUDA_FREE(d_match_count);

    return h_match_count;
}

/* Maps the ratio-test match count to a similarity in [0, 1]:
 * 1 - exp(-3 * matches / min(size1, size2)). NULL or empty inputs score 0.0. */
double sift_matcher_compute_similarity(const FeatureArray* features1, const FeatureArray* features2) {
    const int both_present = (features1 != NULL) && (features2 != NULL);
    if (!both_present || features1->size == 0 || features2->size == 0) {
        return 0.0;
    }

    const int matched = sift_matcher_count_matches(features1, features2);
    const int smaller = (features2->size < features1->size) ? features2->size : features1->size;

    const double fraction = (double)matched / smaller;
    const double score = 1.0 - exp(-3.0 * fraction);
    return (score < 0.0) ? 0.0 : ((score > 1.0) ? 1.0 : score);
}

2.7 sift_types.h

#ifndef SIFT_TYPES_H
#define SIFT_TYPES_H

#ifdef __cplusplus
extern "C" {
#endif

#include <stddef.h>

/* Flavour tag stored in Feature::type. feature_new() defaults new features to
 * FEATURE_LOWE. NOTE(review): the names suggest the Oxford (OXFD) and Lowe
 * keypoint file formats; confirm against whatever code reads/writes them. */
typedef enum {
    FEATURE_OXFD,
    FEATURE_LOWE
} FeatureType;

/* Per-feature detection bookkeeping, embedded in Feature::feature_data and
 * copied verbatim by feature_clone().
 * NOTE(review): the field meanings below are inferred from their names and
 * should be confirmed against the detector that fills them in. */
typedef struct {
    int r;              /* presumably the row of the extremum in its octave image */
    int c;              /* presumably the column of the extremum in its octave image */
    int octv;           /* presumably the octave index */
    int intvl;          /* presumably the scale-interval index within the octave */
    double subintvl;    /* presumably the sub-interval (interpolated) offset */
    double scl_octv;    /* presumably the scale relative to the octave */
} DetectionData;

/* Capacity of Feature::descr (number of descriptor entries). */
#define FEATURE_MAX_D 128

/* One SIFT keypoint with its descriptor.
 * feature_new() zero-initialises every field and sets type = FEATURE_LOWE;
 * feature_clone() copies everything except the three *_match links.
 * NOTE(review): per-field meanings are inferred from names; confirm. */
typedef struct Feature {
    double x;                 /* presumably the keypoint x coordinate */
    double y;                 /* presumably the keypoint y coordinate */
    double a;                 /* presumably elliptical-region parameters (a, b, c) */
    double b;
    double c;
    double scl;               /* presumably the keypoint scale */
    double ori;               /* presumably the orientation */
    int d;                    /* presumably the number of valid entries in descr */
    double descr[FEATURE_MAX_D];
    FeatureType type;
    int category;
    struct Feature* fwd_match;  /* match links: not carried over by feature_clone() */
    struct Feature* bck_match;
    struct Feature* mdl_match;
    float img_pt_x;
    float img_pt_y;
    DetectionData feature_data;
} Feature;

/* Allocate a zero-initialised Feature with type FEATURE_LOWE; NULL if allocation fails. */
Feature* feature_new(void);
/* Release a Feature allocated by feature_new()/feature_clone(); NULL is accepted. */
void feature_free(Feature* feat);
/* Copy of *feat with the fwd/bck/mdl match links left NULL; NULL for a NULL
 * input or when allocation fails. */
Feature* feature_clone(const Feature* feat);

#ifdef __cplusplus
}
#endif

#endif

2.8 sift_types.c

#include "sift_types.h"
#include <stdlib.h>
#include <string.h>

/* Allocate a Feature with every field zeroed (descriptor, match links and
 * detection data included) and its type set to FEATURE_LOWE.
 * Returns NULL if allocation fails. */
Feature* feature_new(void) {
    /* calloc already zero-fills the whole struct, descriptor array included. */
    Feature* created = (Feature*)calloc(1, sizeof(*created));
    if (created == NULL)
        return NULL;
    created->type = FEATURE_LOWE;
    return created;
}

/* Release a Feature; a NULL pointer is a no-op (free(NULL) does nothing). */
void feature_free(Feature* feat) {
    if (feat != NULL)
        free(feat);
}

Feature* feature_clone(const Feature* feat) {
    if (!feat) return NULL;
    Feature* new_feat = feature_new();
    if (!new_feat) return NULL;
    new_feat->x = feat->x;
    new_feat->y = feat->y;
    new_feat->a = feat->a;
    new_feat->b = feat->b;
    new_feat->c = feat->c;
    new_feat->scl = feat->scl;
    new_feat->ori = feat->ori;
    new_feat->d = feat->d;
    new_feat->type = feat->type;
    new_feat->category = feat->category;
    new_feat->img_pt_x = feat->img_pt_x;
    new_feat->img_pt_y = feat->img_pt_y;
    new_feat->feature_data = feat->feature_data;
    memcpy(new_feat->descr, feat->descr, sizeof(feat->descr));
    return new_feat;
}

3. 图像操作模块

3.1 image_ops.h

#ifndef IMAGE_OPS_H
#define IMAGE_OPS_H

#ifdef __cplusplus
extern "C" {
#endif

#include "gray_image.h"

/* Separable Gaussian blur computed on the GPU with edge replication at the
 * borders; sigma <= 0.01 returns a copy of src. The result is a newly
 * allocated image owned by the caller (release with gray_image_free). */
GrayImage* image_ops_gaussian_blur(const GrayImage* src, double sigma);
/* 2x decimation on the GPU: the output is (width/2) x (height/2) and pixel
 * (x, y) takes src(2x, 2y). Newly allocated result. */
GrayImage* image_ops_downsample(const GrayImage* src);
/* Bicubic (Keys kernel, a = -0.5) resize to newW x newH computed on the CPU,
 * with pixel-centre alignment and clamp-to-edge sampling. Newly allocated result. */
GrayImage* image_ops_resize_cubic(const GrayImage* src, int newW, int newH);
/* Per-pixel difference a - b computed on the GPU; the output has a's size.
 * NOTE(review): a and b are expected to have identical dimensions. */
GrayImage* image_ops_subtract(const GrayImage* a, const GrayImage* b);

#ifdef __cplusplus
}
#endif

#endif

3.2 image_ops.cu

#include "image_ops.h"
#include <cuda_runtime.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

// Horizontal pass of a separable 1-D convolution.
// Launch: 2-D grid covering width x height, one thread per output pixel
// (threads outside the image do nothing). Source columns outside [0, width)
// are clamped to the nearest edge column. `kernel` holds kernel_size taps
// centred at index `radius`.
__global__ void gaussian_blur_h_kernel(const float* src, float* dst, int width, int height,
                                       const float* kernel, int kernel_size, int radius) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (col < width && row < height) {
        const float* src_row = src + row * width;
        float acc = 0.0f;
        for (int tap = 0; tap < kernel_size; ++tap) {
            int sample = col - radius + tap;
            sample = (sample < 0) ? 0 : ((sample >= width) ? width - 1 : sample);
            acc += src_row[sample] * kernel[tap];
        }
        dst[row * width + col] = acc;
    }
}

// Vertical pass of a separable 1-D convolution.
// Launch: 2-D grid covering width x height, one thread per output pixel
// (threads outside the image do nothing). Source rows outside [0, height)
// are clamped to the nearest edge row. `kernel` holds kernel_size taps
// centred at index `radius`.
__global__ void gaussian_blur_v_kernel(const float* src, float* dst, int width, int height,
                                       const float* kernel, int kernel_size, int radius) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (col < width && row < height) {
        float acc = 0.0f;
        for (int tap = 0; tap < kernel_size; ++tap) {
            int sample = row - radius + tap;
            if (sample < 0)
                sample = 0;
            else if (sample >= height)
                sample = height - 1;
            acc += src[sample * width + col] * kernel[tap];
        }
        dst[row * width + col] = acc;
    }
}

// 2x decimation: dst(x, y) = src(2x, 2y).
// Launch: 2-D grid covering dstW x dstH, one thread per output pixel.
// The source coordinate is clamped to [0, srcW-1] x [0, srcH-1], so a caller
// whose dstW/dstH exceed srcW/2 / srcH/2 cannot read past the source buffer.
// For the normal dstW = srcW/2, dstH = srcH/2 the clamp never triggers.
__global__ void downsample_kernel(const float* src, float* dst, int srcW, int srcH, int dstW, int dstH) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x >= dstW || y >= dstH) return;
    // An empty source has no pixel to sample.
    if (srcW <= 0 || srcH <= 0) return;
    int sx = min(x * 2, srcW - 1);
    int sy = min(y * 2, srcH - 1);
    dst[y * dstW + x] = src[sy * srcW + sx];
}

// Element-wise difference dst = a - b over a width x height row-major image.
// Launch: 2-D grid covering width x height, one thread per pixel; threads
// outside the image do nothing.
__global__ void subtract_kernel(const float* a, const float* b, float* dst, int width, int height) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const bool inside = (col < width) && (row < height);
    if (!inside)
        return;
    const int offset = row * width + col;
    dst[offset] = a[offset] - b[offset];
}

/* Abort-on-error helper: when `err` is not cudaSuccess, print the failing
 * step (`msg`) and the CUDA error string to stderr, then terminate the whole
 * process with exit status 1. Returns normally on success. */
static void check_cuda(cudaError_t err, const char* msg) {
    if (err == cudaSuccess)
        return;
    fprintf(stderr, "CUDA error (%s): %s\n", msg, cudaGetErrorString(err));
    exit(1);
}

/*
 * Separable Gaussian blur of a grayscale image on the GPU.
 *
 * Builds a normalised 1-D kernel of odd length ceil(6 * sigma) (bumped to the
 * next odd value), runs a horizontal pass into a temporary buffer and then a
 * vertical pass. Border pixels are replicated.
 *
 * src   : source image (not modified). NULL yields NULL.
 * sigma : standard deviation in pixels; sigma <= 0.01 returns a plain copy.
 * Returns a newly allocated image (release with gray_image_free), or NULL on
 * host allocation failure. CUDA API failures terminate the process through
 * check_cuda().
 */
GrayImage* image_ops_gaussian_blur(const GrayImage* src, double sigma) {
    if (!src)
        return NULL;
    if (sigma <= 0.01) {
        return gray_image_clone(src);
    }

    const int w = src->width;
    const int h = src->height;
    /* An empty image has nothing to blur. Returning early also avoids a
     * zero-sized launch grid, which is an invalid configuration that
     * check_cuda() would turn into process exit. */
    if (w <= 0 || h <= 0)
        return gray_image_clone(src);

    /* Odd kernel length covering roughly +/- 3 sigma. */
    int size = (int)ceil(sigma * 6.0);
    if (size % 2 == 0) size++;
    const int radius = size / 2;
    const size_t kernel_bytes = (size_t)size * sizeof(float);
    /* size_t arithmetic: w * h in int can overflow for very large images. */
    const size_t image_bytes = (size_t)w * (size_t)h * sizeof(float);

    float* h_kernel = (float*)malloc(kernel_bytes);
    if (!h_kernel)
        return NULL;
    double sum = 0.0;
    for (int i = 0; i < size; i++) {
        double x = i - radius;
        h_kernel[i] = (float)exp(-(x * x) / (2.0 * sigma * sigma));
        sum += h_kernel[i];
    }
    for (int i = 0; i < size; i++) {
        h_kernel[i] /= (float)sum;
    }

    /* Allocate the result before any GPU work so a host allocation failure
     * does not leak device buffers. */
    GrayImage* dst = gray_image_create(w, h);
    if (!dst) {
        free(h_kernel);
        return NULL;
    }

    float *d_src, *d_tmp, *d_dst, *d_kernel;
    check_cuda(cudaMalloc(&d_src, image_bytes), "cudaMalloc d_src");
    check_cuda(cudaMalloc(&d_tmp, image_bytes), "cudaMalloc d_tmp");
    check_cuda(cudaMalloc(&d_dst, image_bytes), "cudaMalloc d_dst");
    check_cuda(cudaMalloc(&d_kernel, kernel_bytes), "cudaMalloc d_kernel");

    check_cuda(cudaMemcpy(d_src, src->data, image_bytes, cudaMemcpyHostToDevice), "memcpy src");
    check_cuda(cudaMemcpy(d_kernel, h_kernel, kernel_bytes, cudaMemcpyHostToDevice), "memcpy kernel");

    dim3 block(16, 16);
    dim3 grid((w + block.x - 1) / block.x, (h + block.y - 1) / block.y);

    /* Horizontal pass src -> tmp, then vertical pass tmp -> dst. */
    gaussian_blur_h_kernel<<<grid, block>>>(d_src, d_tmp, w, h, d_kernel, size, radius);
    check_cuda(cudaGetLastError(), "gaussian_blur_h_kernel");
    gaussian_blur_v_kernel<<<grid, block>>>(d_tmp, d_dst, w, h, d_kernel, size, radius);
    check_cuda(cudaGetLastError(), "gaussian_blur_v_kernel");

    /* The blocking device-to-host copy waits for both kernels to finish. */
    check_cuda(cudaMemcpy(dst->data, d_dst, image_bytes, cudaMemcpyDeviceToHost), "memcpy dst");

    cudaFree(d_src);
    cudaFree(d_tmp);
    cudaFree(d_dst);
    cudaFree(d_kernel);
    free(h_kernel);

    return dst;
}

/*
 * 2x decimation on the GPU: returns a (width/2) x (height/2) image whose pixel
 * (x, y) is src(2x, 2y).
 *
 * Returns NULL for a NULL src or on host allocation failure. A source
 * narrower or shorter than 2 px yields an empty image without launching any
 * kernel. CUDA API failures terminate the process through check_cuda().
 */
GrayImage* image_ops_downsample(const GrayImage* src) {
    if (!src)
        return NULL;
    const int srcW = src->width;
    const int srcH = src->height;
    const int newW = srcW / 2;
    const int newH = srcH / 2;
    GrayImage* dst = gray_image_create(newW, newH);
    if (!dst)
        return NULL;
    /* A zero-sized launch grid is an invalid configuration that check_cuda()
     * would turn into process exit; an empty result needs no GPU work. */
    if (newW <= 0 || newH <= 0)
        return dst;

    /* size_t arithmetic: width * height in int can overflow for large images. */
    const size_t src_bytes = (size_t)srcW * (size_t)srcH * sizeof(float);
    const size_t dst_bytes = (size_t)newW * (size_t)newH * sizeof(float);

    float *d_src, *d_dst;
    check_cuda(cudaMalloc(&d_src, src_bytes), "cudaMalloc d_src");
    check_cuda(cudaMalloc(&d_dst, dst_bytes), "cudaMalloc d_dst");

    check_cuda(cudaMemcpy(d_src, src->data, src_bytes, cudaMemcpyHostToDevice), "memcpy src");

    dim3 block(16, 16);
    dim3 grid((newW + block.x - 1) / block.x, (newH + block.y - 1) / block.y);
    downsample_kernel<<<grid, block>>>(d_src, d_dst, srcW, srcH, newW, newH);
    check_cuda(cudaGetLastError(), "downsample_kernel");

    check_cuda(cudaMemcpy(dst->data, d_dst, dst_bytes, cudaMemcpyDeviceToHost), "memcpy dst");

    cudaFree(d_src);
    cudaFree(d_dst);
    return dst;
}

/*
 * Per-pixel difference a - b computed on the GPU.
 *
 * Both images must have identical dimensions. Otherwise NULL is returned:
 * copying a->width * a->height floats out of a smaller b would overrun b's
 * buffer. Also returns NULL for NULL inputs or on host allocation failure.
 * Empty images produce an empty result without a kernel launch. CUDA API
 * failures terminate the process through check_cuda().
 */
GrayImage* image_ops_subtract(const GrayImage* a, const GrayImage* b) {
    if (!a || !b)
        return NULL;
    if (a->width != b->width || a->height != b->height)
        return NULL;

    const int w = a->width;
    const int h = a->height;
    GrayImage* dst = gray_image_create(w, h);
    if (!dst)
        return NULL;
    /* Avoid a zero-sized (invalid) launch grid for empty images. */
    if (w <= 0 || h <= 0)
        return dst;

    /* size_t arithmetic: w * h in int can overflow for large images. */
    const size_t bytes = (size_t)w * (size_t)h * sizeof(float);

    float *d_a, *d_b, *d_dst;
    check_cuda(cudaMalloc(&d_a, bytes), "cudaMalloc d_a");
    check_cuda(cudaMalloc(&d_b, bytes), "cudaMalloc d_b");
    check_cuda(cudaMalloc(&d_dst, bytes), "cudaMalloc d_dst");

    check_cuda(cudaMemcpy(d_a, a->data, bytes, cudaMemcpyHostToDevice), "memcpy a");
    check_cuda(cudaMemcpy(d_b, b->data, bytes, cudaMemcpyHostToDevice), "memcpy b");

    dim3 block(16, 16);
    dim3 grid((w + block.x - 1) / block.x, (h + block.y - 1) / block.y);
    subtract_kernel<<<grid, block>>>(d_a, d_b, d_dst, w, h);
    check_cuda(cudaGetLastError(), "subtract_kernel");

    check_cuda(cudaMemcpy(dst->data, d_dst, bytes, cudaMemcpyDeviceToHost), "memcpy dst");

    cudaFree(d_a);
    cudaFree(d_b);
    cudaFree(d_dst);
    return dst;
}

// Bicubic resize, implemented on the CPU (no GPU kernel for this path).

/* Keys cubic convolution weight with a = -0.5:
 *   |x| <= 1     : 1.5|x|^3 - 2.5|x|^2 + 1
 *   1 < |x| <= 2 : -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2
 *   otherwise (including NaN) : 0 */
static double cubic_kernel(double x) {
    const double t = fabs(x);
    if (t <= 1.0)
        return 1.5 * t * t * t - 2.5 * t * t + 1.0;
    return (t <= 2.0) ? (-0.5 * t * t * t + 2.5 * t * t - 4.0 * t + 2.0) : 0.0;
}

/* Clamp v into [lo, hi]; the lower bound is checked first, so lo wins when lo > hi. */
static int clamp(int v, int lo, int hi) {
    if (v < lo)
        return lo;
    if (v > hi)
        return hi;
    return v;
}

/* Sample img at the real-valued position (x = column, y = row) using a 4x4
 * cubic-convolution neighbourhood. Out-of-range neighbours are clamped to the
 * image edge, and the result is normalised by the sum of weights
 * (0.0 if that sum is not positive). */
static double bicubic_interpolate(const GrayImage* img, double x, double y) {
    const int base_x = (int)floor(x);
    const int base_y = (int)floor(y);
    const double frac_x = x - base_x;
    const double frac_y = y - base_y;
    const int last_col = img->width - 1;
    const int last_row = img->height - 1;

    double weighted = 0.0;
    double total_weight = 0.0;
    for (int oy = -1; oy <= 2; ++oy) {
        const int row = clamp(base_y + oy, 0, last_row);
        const double wy = cubic_kernel(frac_y - oy);
        for (int ox = -1; ox <= 2; ++ox) {
            const int col = clamp(base_x + ox, 0, last_col);
            const double weight = cubic_kernel(frac_x - ox) * wy;
            weighted += gray_image_get(img, row, col) * weight;
            total_weight += weight;
        }
    }

    if (total_weight > 0.0)
        return weighted / total_weight;
    return 0.0;
}

/*
 * Bicubic resize of src to newW x newH on the CPU.
 *
 * Output pixel centres are mapped onto the source grid as
 * s = (d + 0.5) * scale - 0.5 and sampled with bicubic_interpolate().
 * Returns NULL for a NULL src or on allocation failure. An empty source
 * yields the zero-initialised output from gray_image_create(), because there
 * are no pixels to sample.
 */
GrayImage* image_ops_resize_cubic(const GrayImage* src, int newW, int newH) {
    if (!src)
        return NULL;
    GrayImage* dst = gray_image_create(newW, newH);
    if (!dst)
        return NULL;
    /* bicubic_interpolate() would clamp to index 0 and read data[0] of an
     * empty buffer, so skip sampling entirely. */
    if (src->width <= 0 || src->height <= 0)
        return dst;

    const double scaleX = (double)src->width / newW;
    const double scaleY = (double)src->height / newH;

    for (int y = 0; y < newH; y++) {
        /* The source row coordinate only depends on y. */
        const double sy = (y + 0.5) * scaleY - 0.5;
        for (int x = 0; x < newW; x++) {
            const double sx = (x + 0.5) * scaleX - 0.5;
            gray_image_set(dst, y, x, (float)bicubic_interpolate(src, sx, sy));
        }
    }
    return dst;
}

3.3 gray_image.h

#ifndef GRAY_IMAGE_H
#define GRAY_IMAGE_H

#ifdef __cplusplus
extern "C" {
#endif

#include <stddef.h>

/* Single-channel float image. Images produced by gray_image_load() hold
 * luminance scaled to roughly [0, 1]; other constructors store whatever
 * values they are given. */
typedef struct {
    int width;   /* number of columns */
    int height;  /* number of rows */
    float* data; // row-major: data[y * width + x]
} GrayImage;

/* Allocate a width x height image with zero-filled pixels; NULL on allocation failure. */
GrayImage* gray_image_create(int width, int height);
/* Allocate a width x height image and copy width * height floats from data. */
GrayImage* gray_image_create_from_data(int width, int height, const float* data);
/* Release img and its pixel buffer; NULL is ignored. */
void gray_image_free(GrayImage* img);
/* Deep copy of img; NULL in, NULL out. */
GrayImage* gray_image_clone(const GrayImage* img);

// Load from file using stb_image
// Decodes the file, converts it to luminance with 0.299/0.587/0.114 weights
// scaled to [0, 1], and returns a new image; NULL if decoding fails.
GrayImage* gray_image_load(const char* filename);

// Read the pixel at (row, col) of the row-major buffer; coordinates are not bounds-checked.
static inline float gray_image_get(const GrayImage* img, int row, int col) {
    const float* row_begin = img->data + row * img->width;
    return row_begin[col];
}

// Store val at (row, col) of the row-major buffer; coordinates are not bounds-checked.
static inline void gray_image_set(GrayImage* img, int row, int col, float val) {
    float* row_begin = img->data + row * img->width;
    row_begin[col] = val;
}

#ifdef __cplusplus
}
#endif

#endif

3.4 gray_image.c

#include "gray_image.h"
#include "stb_image.h"
#include <stdlib.h>
#include <string.h>
#include <math.h>

/* Allocate a width x height image with all pixels zeroed.
 * Returns NULL for negative dimensions or when either allocation fails. */
GrayImage* gray_image_create(int width, int height) {
    /* Negative sizes are meaningless. Reject them instead of letting the
     * product turn into a bogus element count. */
    if (width < 0 || height < 0)
        return NULL;
    GrayImage* img = (GrayImage*)malloc(sizeof(GrayImage));
    if (!img) return NULL;
    img->width = width;
    img->height = height;
    /* Compute the element count in size_t: width * height in int can
     * overflow for large images and yield a buffer smaller than the image.
     * calloc itself rejects a count * size product that overflows. */
    img->data = (float*)calloc((size_t)width * (size_t)height, sizeof(float));
    if (!img->data) {
        free(img);
        return NULL;
    }
    return img;
}

/* Allocate a width x height image and copy width * height floats from data.
 * Returns NULL when allocation fails, or when data is NULL for a non-empty
 * image (there would be nothing valid to copy). */
GrayImage* gray_image_create_from_data(int width, int height, const float* data) {
    GrayImage* img = gray_image_create(width, height);
    if (!img) return NULL;
    /* size_t element count: width * height in int can overflow. */
    const size_t count = (size_t)width * (size_t)height;
    if (count > 0) {
        if (!data) {
            gray_image_free(img);
            return NULL;
        }
        memcpy(img->data, data, count * sizeof(float));
    }
    return img;
}

/* Release an image and its pixel buffer; passing NULL does nothing. */
void gray_image_free(GrayImage* img) {
    if (img == NULL)
        return;
    free(img->data);
    free(img);
}

/* Deep copy of img (dimensions and pixels); NULL in, NULL out. */
GrayImage* gray_image_clone(const GrayImage* img) {
    return (img != NULL)
        ? gray_image_create_from_data(img->width, img->height, img->data)
        : NULL;
}

/*
 * Decode an image file with stb_image and convert it to a grayscale
 * GrayImage with intensities in [0, 1].
 *
 * Colour images (3+ channels) use the 0.299 / 0.587 / 0.114 luminance
 * weights; any alpha channel is ignored. Grey and grey+alpha images
 * (1 or 2 channels) take the first byte of each pixel as the intensity.
 * Returns NULL for a NULL filename, a decode failure, or an allocation failure.
 */
GrayImage* gray_image_load(const char* filename) {
    if (!filename)
        return NULL;

    int w, h, channels;
    unsigned char* pixels = stbi_load(filename, &w, &h, &channels, 0);
    if (!pixels) return NULL;

    GrayImage* img = gray_image_create(w, h);
    if (!img) {
        stbi_image_free(pixels);
        return NULL;
    }

    for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
            /* size_t index: (y * w + x) * channels in int overflows for large images. */
            size_t idx = ((size_t)y * (size_t)w + (size_t)x) * (size_t)channels;
            float r, g, b;
            if (channels >= 3) {
                r = pixels[idx + 0] / 255.0f;
                g = pixels[idx + 1] / 255.0f;
                b = pixels[idx + 2] / 255.0f;
            } else if (channels >= 1) {
                /* 1 = grey, 2 = grey + alpha. Previously 2-channel images fell
                 * through to the black branch below. */
                r = g = b = pixels[idx] / 255.0f;
            } else {
                r = g = b = 0;
            }
            float gray = r * 0.299f + g * 0.587f + b * 0.114f;
            gray_image_set(img, y, x, gray);
        }
    }

    stbi_image_free(pixels);
    return img;
}

4. CUDA Native 导出层

4.1 image_similarity.h

#ifndef IMAGE_SIMILARITY_H
#define IMAGE_SIMILARITY_H

#ifdef __cplusplus
extern "C" {
#endif

#include "gray_image.h"
#include "sift_algorithm.h"

// Simple hash map for template cache
// One node of a bucket's singly linked list. The cache owns both `label`
// (a strdup'ed copy of the caller's string) and `features` (released with
// feature_array_free when the entry is replaced, removed or the cache is freed).
typedef struct TemplateCacheEntry {
    char* label;
    FeatureArray* features;
    struct TemplateCacheEntry* next;
} TemplateCacheEntry;

// Separate-chaining hash table keyed by label; labels are compared with
// strcasecmp (case-insensitive). template_cache_new() creates 64 buckets.
typedef struct {
    TemplateCacheEntry** buckets;  // array of bucket_count list heads
    int bucket_count;
} TemplateCache;

// Evaluator handle created by evaluator_new(); it owns its template cache,
// which evaluator_free() releases.
typedef struct {
    TemplateCache* cache;
} ImageSimilarityEvaluator;

// Create an evaluator with an empty template cache.
ImageSimilarityEvaluator* evaluator_new(void);
// Release the evaluator and every cached template.
void evaluator_free(ImageSimilarityEvaluator* eval);

// Score currentImage against the template stored under `label`, in [0, 1].
// When correctImage is non-NULL its features are extracted and replace the
// cached template for `label`; otherwise the cached template is used.
// Returns 0.0 for missing/empty inputs, an unknown label, or when either
// image yields no features.
double evaluator_evaluate(ImageSimilarityEvaluator* eval, const GrayImage* currentImage,
                          const char* label, const GrayImage* correctImage);

// Extract features from correctImage and store them under `label`
// (replacing any existing entry). Labels match case-insensitively.
void evaluator_preload_template(ImageSimilarityEvaluator* eval, const char* label,
                                const GrayImage* correctImage);
// Drop the cached template for `label` (no-op when absent or label is NULL).
void evaluator_clear_cache(ImageSimilarityEvaluator* eval, const char* label);
// Drop every cached template.
void evaluator_clear_all_cache(ImageSimilarityEvaluator* eval);

#ifdef __cplusplus
}
#endif

#endif

4.2 image_similarity.cu

#include "image_similarity.h"
#include "sift_matcher.h"
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

// For cross-platform compatibility
#ifdef _WIN32
#include <string.h>
#define strcasecmp _stricmp
#define strdup _strdup
#else
#include <strings.h>
#endif

/*
 * djb2 hash (h = h * 33 + c, seeded with 5381) over the ASCII-lower-cased
 * bytes of `str`.
 *
 * The template cache compares labels with strcasecmp(), so the hash must
 * ignore case too. Otherwise "Foo" and "foo" land in different buckets, the
 * case-insensitive lookup never sees the existing entry, and set() inserts a
 * duplicate. Only ASCII letters are folded, which matches strcasecmp in the
 * default "C" locale. Bytes are read as unsigned so the result does not
 * depend on the signedness of `char`.
 */
static unsigned int hash_string(const char* str) {
    unsigned int hash = 5381;
    const unsigned char* p = (const unsigned char*)str;
    unsigned int c;
    while ((c = *p++) != 0) {
        if (c >= 'A' && c <= 'Z')
            c += 'a' - 'A';
        hash = ((hash << 5) + hash) + c;
    }
    return hash;
}

/* Create an empty cache with 64 buckets.
 * Returns NULL if either allocation fails (nothing is leaked). */
static TemplateCache* template_cache_new(void) {
    TemplateCache* cache = (TemplateCache*)malloc(sizeof(TemplateCache));
    if (!cache)
        return NULL;
    cache->bucket_count = 64;
    cache->buckets = (TemplateCacheEntry**)calloc(cache->bucket_count, sizeof(TemplateCacheEntry*));
    if (!cache->buckets) {
        free(cache);
        return NULL;
    }
    return cache;
}

/* Destroy the cache: every entry's label and feature array, the bucket array
 * and the cache struct itself. NULL is a no-op. */
static void template_cache_free(TemplateCache* cache) {
    if (cache == NULL)
        return;
    for (int b = 0; b < cache->bucket_count; ++b) {
        TemplateCacheEntry* node = cache->buckets[b];
        while (node != NULL) {
            TemplateCacheEntry* following = node->next;
            free(node->label);
            feature_array_free(node->features);
            free(node);
            node = following;
        }
    }
    free(cache->buckets);
    free(cache);
}

/* Look up the features cached under `label` (case-insensitive comparison).
 * Returns the cache-owned array, or NULL when the label is absent. */
static FeatureArray* template_cache_get(TemplateCache* cache, const char* label) {
    const unsigned int bucket = hash_string(label) % cache->bucket_count;
    for (TemplateCacheEntry* node = cache->buckets[bucket]; node != NULL; node = node->next) {
        if (strcasecmp(node->label, label) == 0)
            return node->features;
    }
    return NULL;
}

/*
 * Store `features` under `label` (case-insensitive), taking ownership of it.
 *
 * An existing entry's previous array is freed and replaced. Storing the
 * pointer that is already cached is a no-op rather than a use-after-free.
 * Returns 0 on success, or -1 if a new entry (or its label copy) cannot be
 * allocated. On failure nothing is freed and ownership of `features` stays
 * with the caller, so a caller that keeps using `features` is not left with
 * a dangling pointer.
 */
static int template_cache_set(TemplateCache* cache, const char* label, FeatureArray* features) {
    unsigned int h = hash_string(label) % cache->bucket_count;
    for (TemplateCacheEntry* existing = cache->buckets[h]; existing; existing = existing->next) {
        if (strcasecmp(existing->label, label) == 0) {
            if (existing->features != features)
                feature_array_free(existing->features);
            existing->features = features;
            return 0;
        }
    }

    TemplateCacheEntry* entry = (TemplateCacheEntry*)malloc(sizeof(TemplateCacheEntry));
    if (!entry)
        return -1;
    entry->label = strdup(label);
    if (!entry->label) {
        /* A NULL label would crash the next strcasecmp() on this bucket. */
        free(entry);
        return -1;
    }
    entry->features = features;
    entry->next = cache->buckets[h];
    cache->buckets[h] = entry;
    return 0;
}

/* Unlink and destroy the first entry whose label matches `label`
 * case-insensitively; does nothing when no entry matches. */
static void template_cache_remove(TemplateCache* cache, const char* label) {
    const unsigned int bucket = hash_string(label) % cache->bucket_count;
    TemplateCacheEntry* prev = NULL;
    TemplateCacheEntry* node = cache->buckets[bucket];
    while (node != NULL && strcasecmp(node->label, label) != 0) {
        prev = node;
        node = node->next;
    }
    if (node == NULL)
        return;

    if (prev == NULL)
        cache->buckets[bucket] = node->next;
    else
        prev->next = node->next;

    free(node->label);
    feature_array_free(node->features);
    free(node);
}

/* Single point through which this module extracts SIFT features; the caller
 * owns the returned array. */
static FeatureArray* extract_features(const GrayImage* img) {
    FeatureArray* features = sift_extract_features(img);
    return features;
}

/* Create an evaluator with an empty template cache.
 * Returns NULL if any allocation fails. The C# wrapper treats a NULL handle
 * as a construction failure. */
ImageSimilarityEvaluator* evaluator_new(void) {
    ImageSimilarityEvaluator* eval = (ImageSimilarityEvaluator*)malloc(sizeof(ImageSimilarityEvaluator));
    if (!eval)
        return NULL;
    eval->cache = template_cache_new();
    if (!eval->cache) {
        free(eval);
        return NULL;
    }
    return eval;
}

/* Destroy an evaluator together with its template cache; NULL is a no-op. */
void evaluator_free(ImageSimilarityEvaluator* eval) {
    if (eval != NULL) {
        template_cache_free(eval->cache);
        free(eval);
    }
}

/*
 * Score currentImage against the template for `label`, in [0, 1].
 *
 * With a non-NULL correctImage, its features are extracted and replace the
 * cached template for `label` (the cache owns them). Otherwise the cached
 * template is used.
 * Returns 0.0 when eval, currentImage or label is missing/empty, when no
 * template exists, or when either side yields no features.
 */
double evaluator_evaluate(ImageSimilarityEvaluator* eval, const GrayImage* currentImage,
                          const char* label, const GrayImage* correctImage) {
    if (!eval || !eval->cache || !currentImage || !label || label[0] == '\0')
        return 0.0;

    FeatureArray* templateFeatures = NULL;
    if (correctImage != NULL) {
        templateFeatures = extract_features(correctImage);
        template_cache_set(eval->cache, label, templateFeatures);
    } else {
        templateFeatures = template_cache_get(eval->cache, label);
    }

    if (!templateFeatures)
        return 0.0;

    FeatureArray* currentFeatures = extract_features(currentImage);
    /* Guard against extraction reporting failure with NULL. */
    if (!currentFeatures)
        return 0.0;

    if (currentFeatures->size == 0 || templateFeatures->size == 0) {
        feature_array_free(currentFeatures);
        return 0.0;
    }

    double similarity = sift_matcher_compute_similarity(currentFeatures, templateFeatures);
    feature_array_free(currentFeatures);
    return similarity;
}

/* Extract features from correctImage and cache them under `label`, replacing
 * any existing template. Ignores the call when eval (or its cache),
 * correctImage or label is missing, or when label is empty. */
void evaluator_preload_template(ImageSimilarityEvaluator* eval, const char* label,
                                const GrayImage* correctImage) {
    if (!eval || !eval->cache || !correctImage || !label || strlen(label) == 0)
        return;

    FeatureArray* features = extract_features(correctImage);
    template_cache_set(eval->cache, label, features);
}

/* Remove the cached template for `label`; a NULL evaluator, cache or label
 * makes this a no-op instead of a crash. */
void evaluator_clear_cache(ImageSimilarityEvaluator* eval, const char* label) {
    if (!eval || !eval->cache || !label) return;
    template_cache_remove(eval->cache, label);
}

/* Drop every cached template by replacing the cache with a fresh, empty one.
 * A NULL evaluator is ignored. */
void evaluator_clear_all_cache(ImageSimilarityEvaluator* eval) {
    if (!eval) return;
    template_cache_free(eval->cache);
    eval->cache = template_cache_new();
}

5. C# 封装库

5.1 CudaBinarizeLib.cs

using System;
using System.IO;
using System.Runtime.InteropServices;

namespace CudaSharp
{
    /// <summary>
    /// Configuration for the CUDA image binarizer; each property carries its default value.
    /// </summary>
    public class BinarizeConfig
    {
        private float _gamma = 1.0f;
        private float _offset = 0.0f;
        private int _winRadius = 25;
        private float _sauvolaK = 0.15f;
        private bool _useSauvola = true;

        /// <summary>Gamma correction value (1.0 = disabled).</summary>
        public float Gamma { get => _gamma; set => _gamma = value; }

        /// <summary>Threshold offset, typically in -0.1 .. 0.1.</summary>
        public float Offset { get => _offset; set => _offset = value; }

        /// <summary>Window radius in pixels; 15-25 is recommended.</summary>
        public int WinRadius { get => _winRadius; set => _winRadius = value; }

        /// <summary>Sauvola sensitivity, typically in 0.1 .. 0.5.</summary>
        public float SauvolaK { get => _sauvolaK; set => _sauvolaK = value; }

        /// <summary>Whether the Sauvola algorithm is used.</summary>
        public bool UseSauvola { get => _useSauvola; set => _useSauvola = value; }
    }

    /// <summary>
    /// CUDA image binarizer backed by a native handle from CudaSharpNative.dll.
    /// </summary>
    public class CudaBinarizer : IDisposable
    {
        private IntPtr _handle;
        private bool _disposed = false;

        private const string DllName = "CudaSharpNative.dll";

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr CreateBinarizer(float gamma, float offset, int winRadius,
                                                     float sauvolaK, [MarshalAs(UnmanagedType.U1)] bool useSauvola);

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void DestroyBinarizer(IntPtr handle);

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern int ProcessStream(IntPtr handle,
                                                [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 2)] byte[] inputData,
                                                int inputSize,
                                                out IntPtr outputData,
                                                out int outputSize);

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void FreeMemory(IntPtr ptr);

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr GetCudaLastError();

        /// <summary>
        /// Creates a native binarizer configured from <paramref name="config"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The native library returned a null handle.</exception>
        public CudaBinarizer(BinarizeConfig config)
        {
            if (config == null)
                throw new ArgumentNullException(nameof(config));

            _handle = CreateBinarizer(config.Gamma, config.Offset, config.WinRadius,
                                      config.SauvolaK, config.UseSauvola);
            if (_handle == IntPtr.Zero)
            {
                throw new InvalidOperationException("Failed to create CUDA binarizer: " + GetCudaLastErrorMessage());
            }
        }

        /// <summary>
        /// Binarizes an encoded image held in memory.
        /// </summary>
        /// <param name="inputData">Encoded input image bytes.</param>
        /// <returns>
        /// Encoded output image bytes as produced by the native library.
        /// NOTE(review): earlier docs said JPEG here but BMP for the stream overload; confirm against the native side.
        /// </returns>
        /// <exception cref="ObjectDisposedException">The binarizer has been disposed.</exception>
        /// <exception cref="ArgumentException"><paramref name="inputData"/> is null or empty.</exception>
        /// <exception cref="InvalidOperationException">The native call reported failure.</exception>
        public byte[] ProcessMemory(byte[] inputData)
        {
            if (_disposed) throw new ObjectDisposedException(nameof(CudaBinarizer));
            if (inputData == null || inputData.Length == 0)
                throw new ArgumentException("Input data cannot be null or empty", nameof(inputData));

            IntPtr outputPtr = IntPtr.Zero;
            int outputSize = 0;

            int result = ProcessStream(_handle, inputData, inputData.Length, out outputPtr, out outputSize);

            try
            {
                if (result != 0 || outputPtr == IntPtr.Zero)
                {
                    throw new InvalidOperationException("Failed to process image: " + GetCudaLastErrorMessage());
                }

                byte[] outputData = new byte[outputSize];
                Marshal.Copy(outputPtr, outputData, 0, outputSize);
                return outputData;
            }
            finally
            {
                // Release the native buffer on every path. This includes a failure
                // status that still handed one back, which previously leaked.
                if (outputPtr != IntPtr.Zero)
                    FreeMemory(outputPtr);
            }
        }

        /// <summary>
        /// Reads an entire image from <paramref name="inputStream"/>, binarizes it and
        /// writes the encoded result to <paramref name="outputStream"/>.
        /// A <see cref="MemoryStream"/> input is taken whole (<see cref="MemoryStream.ToArray"/>,
        /// regardless of its Position); other streams are read from their current position.
        /// </summary>
        /// <param name="inputStream">Input image stream.</param>
        /// <param name="outputStream">Destination for the encoded output image.</param>
        public void ProcessStream(Stream inputStream, Stream outputStream)
        {
            if (_disposed) throw new ObjectDisposedException(nameof(CudaBinarizer));
            if (inputStream == null) throw new ArgumentNullException(nameof(inputStream));
            if (outputStream == null) throw new ArgumentNullException(nameof(outputStream));

            byte[] inputData;
            if (inputStream is MemoryStream memStream)
            {
                inputData = memStream.ToArray();
            }
            else
            {
                using (var tempMs = new MemoryStream())
                {
                    inputStream.CopyTo(tempMs);
                    inputData = tempMs.ToArray();
                }
            }

            byte[] outputData = ProcessMemory(inputData);
            outputStream.Write(outputData, 0, outputData.Length);
        }

        /// <summary>
        /// Returns the native library's last error message, or "Unknown error" when none is available.
        /// </summary>
        public string GetCudaLastErrorMessage()
        {
            IntPtr errorPtr = GetCudaLastError();
            return errorPtr != IntPtr.Zero ? Marshal.PtrToStringAnsi(errorPtr) : "Unknown error";
        }

        /// <summary>
        /// Destroys the native binarizer. Safe to call more than once.
        /// </summary>
        public void Dispose()
        {
            if (!_disposed)
            {
                if (_handle != IntPtr.Zero)
                {
                    DestroyBinarizer(_handle);
                    _handle = IntPtr.Zero;
                }
                _disposed = true;
            }
            GC.SuppressFinalize(this);
        }

        ~CudaBinarizer()
        {
            Dispose();
        }
    }

    /// <summary>
    /// Stateless SIFT image comparison helpers backed by CudaSharpNative.dll.
    /// </summary>
    public static class CudaSift
    {
        private const string DllName = "CudaSharpNative.dll";

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern double CompareImagesSift(
            [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)] byte[] data1, int size1,
            [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 3)] byte[] data2, int size2);

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern int CountMatchesSift(
            [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)] byte[] data1, int size1,
            [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 3)] byte[] data2, int size2);

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr GetCudaLastError();

        private static string GetCudaLastErrorMessage()
        {
            IntPtr errorPtr = GetCudaLastError();
            if (errorPtr == IntPtr.Zero)
                return "Unknown error";
            return Marshal.PtrToStringAnsi(errorPtr);
        }

        // Throws ArgumentException(message, paramName) for a null or empty image buffer.
        private static void RequireImage(byte[] image, string message, string paramName)
        {
            if (image == null || image.Length == 0)
                throw new ArgumentException(message, paramName);
        }

        /// <summary>
        /// SIFT similarity of two encoded images as a score in 0.0 .. 1.0.
        /// </summary>
        /// <exception cref="ArgumentException">Either image is null or empty.</exception>
        /// <exception cref="InvalidOperationException">The native call returned a negative value.</exception>
        public static double CompareImages(byte[] image1, byte[] image2)
        {
            RequireImage(image1, "Image1 cannot be null or empty", nameof(image1));
            RequireImage(image2, "Image2 cannot be null or empty", nameof(image2));

            double similarity = CompareImagesSift(image1, image1.Length, image2, image2.Length);
            if (similarity < 0)
                throw new InvalidOperationException("Failed to compare images: " + GetCudaLastErrorMessage());
            return similarity;
        }

        /// <summary>
        /// Number of SIFT matches between two encoded images.
        /// </summary>
        /// <exception cref="ArgumentException">Either image is null or empty.</exception>
        /// <exception cref="InvalidOperationException">The native call returned a negative value.</exception>
        public static int CountMatches(byte[] image1, byte[] image2)
        {
            RequireImage(image1, "Image1 cannot be null or empty", nameof(image1));
            RequireImage(image2, "Image2 cannot be null or empty", nameof(image2));

            int matchCount = CountMatchesSift(image1, image1.Length, image2, image2.Length);
            if (matchCount < 0)
                throw new InvalidOperationException("Failed to count matches: " + GetCudaLastErrorMessage());
            return matchCount;
        }
    }

    /// <summary>
    /// CUDA SIFT image-similarity evaluator with a native, per-label template cache.
    /// </summary>
    public class CudaImageSimilarityEvaluator : IDisposable
    {
        private IntPtr _handle;
        private bool _disposed = false;
        private const string DllName = "CudaSharpNative.dll";

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr EvaluatorNew();

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void EvaluatorFree(IntPtr handle);

        // SizeParamIndex is the zero-based index of the parameter holding the array length:
        // handle=0, currentData=1, currentSize=2, label=3, correctData=4, correctSize=5.
        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern double EvaluatorEvaluate(IntPtr handle,
            [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 2)] byte[] currentData, int currentSize,
            [MarshalAs(UnmanagedType.LPStr)] string label,
            [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 5)] byte[] correctData, int correctSize);

        // handle=0, label=1, correctData=2, correctSize=3.
        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void EvaluatorPreloadTemplate(IntPtr handle,
            [MarshalAs(UnmanagedType.LPStr)] string label,
            [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 3)] byte[] correctData, int correctSize);

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void EvaluatorClearCache(IntPtr handle,
            [MarshalAs(UnmanagedType.LPStr)] string label);

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void EvaluatorClearAllCache(IntPtr handle);

        [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr GetCudaLastError();

        /// <summary>
        /// Creates an evaluator with an empty template cache.
        /// </summary>
        /// <exception cref="InvalidOperationException">The native library returned a null handle.</exception>
        public CudaImageSimilarityEvaluator()
        {
            _handle = EvaluatorNew();
            if (_handle == IntPtr.Zero)
                throw new InvalidOperationException("Failed to create evaluator: " + GetCudaLastErrorMessage());
        }

        /// <summary>
        /// Scores the current image against the template registered under <paramref name="label"/>.
        /// </summary>
        /// <param name="currentImage">Encoded current image.</param>
        /// <param name="label">Template label.</param>
        /// <param name="correctImage">Encoded template image. When supplied it replaces the cached template
        /// for <paramref name="label"/>; when null the cached template is used.</param>
        /// <returns>Similarity score in 0.0 .. 1.0.</returns>
        public double Evaluate(byte[] currentImage, string label, byte[] correctImage = null)
        {
            if (_disposed) throw new ObjectDisposedException(nameof(CudaImageSimilarityEvaluator));
            if (currentImage == null || currentImage.Length == 0)
                throw new ArgumentException("Current image cannot be null or empty", nameof(currentImage));
            if (string.IsNullOrEmpty(label))
                throw new ArgumentException("Label cannot be null or empty", nameof(label));

            double result = EvaluatorEvaluate(_handle,
                currentImage, currentImage.Length, label,
                correctImage, correctImage?.Length ?? 0);
            if (result < 0)
                throw new InvalidOperationException("Failed to evaluate similarity: " + GetCudaLastErrorMessage());
            return result;
        }

        /// <summary>
        /// Extracts features from <paramref name="correctImage"/> and caches them under <paramref name="label"/>.
        /// </summary>
        public void PreloadTemplate(string label, byte[] correctImage)
        {
            if (_disposed) throw new ObjectDisposedException(nameof(CudaImageSimilarityEvaluator));
            if (string.IsNullOrEmpty(label))
                throw new ArgumentException("Label cannot be null or empty", nameof(label));
            if (correctImage == null || correctImage.Length == 0)
                throw new ArgumentException("Template image cannot be null or empty", nameof(correctImage));

            EvaluatorPreloadTemplate(_handle, label, correctImage, correctImage.Length);
        }

        /// <summary>
        /// Removes the cached template for <paramref name="label"/>.
        /// </summary>
        public void ClearCache(string label)
        {
            if (_disposed) throw new ObjectDisposedException(nameof(CudaImageSimilarityEvaluator));
            EvaluatorClearCache(_handle, label);
        }

        /// <summary>
        /// Removes every cached template.
        /// </summary>
        public void ClearAllCache()
        {
            if (_disposed) throw new ObjectDisposedException(nameof(CudaImageSimilarityEvaluator));
            EvaluatorClearAllCache(_handle);
        }

        private static string GetCudaLastErrorMessage()
        {
            IntPtr errorPtr = GetCudaLastError();
            return errorPtr != IntPtr.Zero ? Marshal.PtrToStringAnsi(errorPtr) : "Unknown error";
        }

        /// <summary>
        /// Frees the native evaluator and its cache. Safe to call more than once.
        /// </summary>
        public void Dispose()
        {
            if (!_disposed)
            {
                if (_handle != IntPtr.Zero)
                {
                    EvaluatorFree(_handle);
                    _handle = IntPtr.Zero;
                }
                _disposed = true;
            }
            GC.SuppressFinalize(this);
        }

        ~CudaImageSimilarityEvaluator()
        {
            Dispose();
        }
    }
}

5.2 CudaBinarizeLib.csproj

<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" />
  
  <PropertyGroup>
    <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
    <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
    <!-- Must be a well-formed 8-4-4-4-12 hex GUID. If a .sln references this
         project by GUID, keep that entry in sync with this value. -->
    <ProjectGuid>{8A3F2E9B-4C7D-1F5A-6B8E-9C0D1A2B3C4D}</ProjectGuid>
    <OutputType>Library</OutputType>
    <RootNamespace>CudaSharp</RootNamespace>
    <AssemblyName>CudaSharp</AssemblyName>
    <TargetFrameworkVersion>v4.7.2</TargetFrameworkVersion>
    <FileAlignment>512</FileAlignment>
    <Deterministic>true</Deterministic>
    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
    <!-- The native CudaSharpNative.dll is 64-bit only. -->
    <PlatformTarget>x64</PlatformTarget>
    <Prefer32Bit>false</Prefer32Bit>
  </PropertyGroup>
  
  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
    <DebugSymbols>true</DebugSymbols>
    <DebugType>full</DebugType>
    <Optimize>false</Optimize>
    <OutputPath>bin\Debug\</OutputPath>
    <DefineConstants>DEBUG;TRACE</DefineConstants>
    <ErrorReport>prompt</ErrorReport>
    <WarningLevel>4</WarningLevel>
    <PlatformTarget>x64</PlatformTarget>
  </PropertyGroup>
  
  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
    <DebugType>pdbonly</DebugType>
    <Optimize>true</Optimize>
    <OutputPath>bin\Release\</OutputPath>
    <DefineConstants>TRACE</DefineConstants>
    <ErrorReport>prompt</ErrorReport>
    <WarningLevel>4</WarningLevel>
    <PlatformTarget>x64</PlatformTarget>
  </PropertyGroup>

  <!-- System assembly references -->
  <ItemGroup>
    <Reference Include="System" />
    <Reference Include="System.Core" />
    <Reference Include="System.Xml.Linq" />
    <Reference Include="System.Data.DataSetExtensions" />
    <Reference Include="Microsoft.CSharp" />
    <Reference Include="System.Data" />
    <Reference Include="System.Net.Http" />
    <Reference Include="System.Xml" />
  </ItemGroup>

  <!-- C# sources -->
  <ItemGroup>
    <Compile Include="CudaBinarizeLib.cs" />
    <!-- Add further C# source files here, e.g.: -->
    <!-- <Compile Include="OtherFile.cs" /> -->
  </ItemGroup>

  <!-- Copy the native (C++/CUDA) DLL to the output directory -->
  <ItemGroup>
    <None Include="build\CudaSharpNative.dll">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
  </ItemGroup>

  <!-- After building, copy the managed DLL into the build directory -->
  <Target Name="CopyDllToBuild" AfterTargets="Build">
    <Copy 
      SourceFiles="$(TargetPath)" 
      DestinationFolder="$(MSBuildProjectDirectory)\build" 
      SkipUnchangedFiles="true" />
    <!-- Also copy the PDB debug symbols when present -->
    <Copy 
      SourceFiles="$(TargetDir)$(TargetName).pdb" 
      DestinationFolder="$(MSBuildProjectDirectory)\build" 
      Condition="Exists('$(TargetDir)$(TargetName).pdb')" 
      SkipUnchangedFiles="true" />
  </Target>

  <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
</Project>

5.3 Program.cs

using System;
using System.IO;
using CudaSharp;

namespace CudaSharpDemo
{
    /// <summary>
    /// Console demo for the CUDA binarizer: reads one image file, runs it through
    /// <c>CudaBinarizer.ProcessStream</c> and writes the result to
    /// "csharp_output.jpg" in the current working directory.
    /// </summary>
    class Program
    {
        /// <summary>Program entry point.</summary>
        /// <param name="args">Optional: args[0] is the input image path (default "test.jpg").</param>
        /// <remarks>
        /// <see cref="Environment.ExitCode"/> is set to 1 when the input file is missing or
        /// processing throws, so calling scripts can detect failure (previously the
        /// process always exited with 0). Console output is unchanged.
        /// </remarks>
        static void Main(string[] args)
        {
            Console.WriteLine("========================================");
            Console.WriteLine("CUDA 图像二值化 C# 示例程序");
            Console.WriteLine("========================================\n");

            var config = new BinarizeConfig
            {
                Gamma = 1.0f,
                Offset = 0.0f,
                WinRadius = 25,
                SauvolaK = 0.15f,
                UseSauvola = true
            };

            Console.WriteLine("配置参数:");
            Console.WriteLine($"  Gamma: {config.Gamma}");
            Console.WriteLine($"  Offset: {config.Offset}");
            Console.WriteLine($"  WinRadius: {config.WinRadius}");
            Console.WriteLine($"  SauvolaK: {config.SauvolaK}");
            Console.WriteLine($"  UseSauvola: {config.UseSauvola}\n");

            try
            {
                Console.WriteLine("初始化 CUDA 二值化处理器...");
                using (var binarizer = new CudaBinarizer(config))
                {
                    string inputFile = args.Length > 0 ? args[0] : "test.jpg";

                    if (!File.Exists(inputFile))
                    {
                        Console.WriteLine($"错误: 输入文件不存在: {inputFile}");
                        Environment.ExitCode = 1;
                        return;
                    }

                    // Exercise the Stream-based API.
                    Console.WriteLine($"处理文件: {inputFile}");
                    // Stopwatch is monotonic and high-resolution; differences of DateTime.Now
                    // are affected by wall-clock adjustments and have coarse resolution.
                    var stopwatch = System.Diagnostics.Stopwatch.StartNew();

                    using (FileStream fsIn = new FileStream(inputFile, FileMode.Open, FileAccess.Read))
                    using (MemoryStream msOut = new MemoryStream())
                    {
                        binarizer.ProcessStream(fsIn, msOut);

                        string outputFile = "csharp_output.jpg";
                        File.WriteAllBytes(outputFile, msOut.ToArray());

                        // Measured time covers reading, processing and writing the output file.
                        var elapsed = stopwatch.Elapsed;
                        Console.WriteLine($"✓ 处理成功!");
                        Console.WriteLine($"  输出文件: {outputFile}");
                        Console.WriteLine($"  输出大小: {msOut.Length / 1024} KB");
                        Console.WriteLine($"  处理时间: {elapsed.TotalMilliseconds:F2} ms");
                    }
                }

                Console.WriteLine("\n========================================");
                Console.WriteLine("所有测试完成!");
                Console.WriteLine("========================================");
            }
            catch (Exception ex)
            {
                Console.WriteLine($"\n错误: {ex.Message}");
                Console.WriteLine($"堆栈跟踪:\n{ex.StackTrace}");
                Environment.ExitCode = 1;
            }
        }
    }
}

6. 测试程序

6.1 test_sift.cpp

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"

#include "gray_image.h"
#include "sift_algorithm.h"
#include "sift_matcher.h"

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <time.h>
#include <float.h>

/* Number of descriptor entries compared by descriptor_distance().
 * NOTE(review): assumed to equal the length of Feature::descr (128 for standard SIFT) — confirm against sift_types.h. */
#define FEATURE_MAX_D 128

/* One accepted correspondence produced by match_features(). */
typedef struct {
    int idx1;    /* index into the first FeatureArray's data[] */
    int idx2;    /* index into the second FeatureArray's data[] (nearest neighbour of idx1) */
    double dist; /* Euclidean descriptor distance between the two features */
} MatchPair;

/*
 * Load an image file with stb_image and convert it to a single-channel GrayImage.
 *
 * 8-bit samples are scaled to float [0,1]. Images with 3 or more channels are
 * converted with the luma weights 0.299/0.587/0.114 applied to R/G/B (a 4th alpha
 * channel is ignored). 1-channel (gray) and 2-channel (gray+alpha) images use
 * channel 0 directly; previously 2-channel images were silently turned black.
 *
 * Returns NULL on failure (a message is printed to stderr when the file cannot
 * be decoded). The caller owns the result and releases it with gray_image_free().
 */
static GrayImage* load_gray_image(const char* path) {
    int w, h, channels;
    unsigned char* pixels = stbi_load(path, &w, &h, &channels, 0);
    if (!pixels) {
        fprintf(stderr, "Failed to load image: %s\n", path);
        return NULL;
    }
    GrayImage* img = gray_image_create(w, h);
    if (!img) {
        stbi_image_free(pixels);
        return NULL;
    }
    for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
            /* size_t arithmetic: (y*w + x) * channels can exceed INT_MAX for large RGBA images. */
            size_t idx = ((size_t)y * (size_t)w + (size_t)x) * (size_t)channels;
            float gray;
            if (channels >= 3) {
                float r = pixels[idx + 0] / 255.0f;
                float g = pixels[idx + 1] / 255.0f;
                float b = pixels[idx + 2] / 255.0f;
                gray = r * 0.299f + g * 0.587f + b * 0.114f;
            } else {
                /* 1 = gray, 2 = gray + alpha: luminance is always channel 0. */
                gray = pixels[idx] / 255.0f;
            }
            gray_image_set(img, y, x, gray);
        }
    }
    stbi_image_free(pixels);
    return img;
}

/*
 * Euclidean (L2) distance between the descriptors of two features.
 * Compares the first FEATURE_MAX_D entries of descr[] and accumulates the
 * squared differences in double precision before taking the square root.
 * NOTE(review): assumes both descr arrays hold at least FEATURE_MAX_D entries.
 */
static double descriptor_distance(const Feature* a, const Feature* b) {
    const int count = FEATURE_MAX_D;
    double acc = 0.0;
    int k = 0;
    while (k < count) {
        const double delta = a->descr[k] - b->descr[k];
        acc += delta * delta;
        ++k;
    }
    return sqrt(acc);
}

/*
 * Brute-force nearest-neighbour matching with Lowe's ratio test.
 *
 * For every feature in f1, finds the closest and second-closest features in f2
 * by descriptor_distance() and accepts the closest one when
 * best / secondBest < 0.75.
 *
 * On return *out_matches is either NULL (empty input or allocation failure) or
 * a malloc'd array of up to f1->size entries that the caller must free(), even
 * when the returned count is 0.
 * Returns the number of accepted matches.
 */
static int match_features(const FeatureArray* f1, const FeatureArray* f2, MatchPair** out_matches) {
    if (!out_matches) return 0;
    /* Always leave the out-parameter defined so callers can free() it unconditionally. */
    *out_matches = NULL;
    if (!f1 || !f2 || f1->size == 0 || f2->size == 0) return 0;

    int n1 = f1->size;
    int n2 = f2->size;
    const double ratioThresh = 0.75;

    /* At most one match per f1 feature, so n1 slots always suffice. */
    MatchPair* matches = (MatchPair*)malloc((size_t)n1 * sizeof(MatchPair));
    if (!matches) {
        fprintf(stderr, "match_features: failed to allocate %d match slots\n", n1);
        return 0;
    }
    int match_count = 0;

    for (int i = 0; i < n1; i++) {
        double bestDist = DBL_MAX;
        double secondBestDist = DBL_MAX;
        int bestIdx = -1;

        for (int j = 0; j < n2; j++) {
            double dist = descriptor_distance(f1->data[i], f2->data[j]);
            if (dist < bestDist) {
                secondBestDist = bestDist;
                bestDist = dist;
                bestIdx = j;
            } else if (dist < secondBestDist) {
                secondBestDist = dist;
            }
        }

        /* Ratio test. secondBestDist > 0 guards the division; when f2 holds a single
         * feature secondBestDist stays DBL_MAX, so that candidate is accepted. */
        if (bestIdx >= 0 && secondBestDist > 0 && bestDist / secondBestDist < ratioThresh) {
            matches[match_count].idx1 = i;
            matches[match_count].idx2 = bestIdx;
            matches[match_count].dist = bestDist;
            match_count++;
        }
    }

    *out_matches = matches;
    return match_count;
}

/*
 * Store one RGB colour at (x, y) in a packed, row-major buffer of w*h pixels
 * with 3 bytes per pixel. Coordinates outside [0,w) x [0,h) are silently ignored,
 * which lets the drawing helpers clip for free.
 */
static void set_pixel(unsigned char* img, int w, int h, int x, int y, unsigned char r, unsigned char g, unsigned char b) {
    const int inside = (x >= 0) && (x < w) && (y >= 0) && (y < h);
    if (inside) {
        unsigned char* px = img + (y * w + x) * 3;
        px[0] = r;
        px[1] = g;
        px[2] = b;
    }
}

/*
 * Rasterize the segment (x0,y0)-(x1,y1) in colour (r,g,b) with the all-octant
 * integer Bresenham algorithm. Both endpoints are drawn; pixels that fall
 * outside the w x h image are clipped by set_pixel().
 */
static void draw_line(unsigned char* img, int w, int h, int x0, int y0, int x1, int y1, unsigned char r, unsigned char g, unsigned char b) {
    const int span_x = abs(x1 - x0);
    const int span_y = abs(y1 - y0);
    const int step_x = (x0 < x1) ? 1 : -1;
    const int step_y = (y0 < y1) ? 1 : -1;
    int px = x0;
    int py = y0;
    int error = span_x - span_y;

    for (;;) {
        set_pixel(img, w, h, px, py, r, g, b);
        if (px == x1 && py == y1) {
            break;
        }
        const int doubled = error * 2;
        /* Advance along x and/or y depending on which side of the ideal line we are. */
        if (doubled > -span_y) {
            error -= span_y;
            px += step_x;
        }
        if (doubled < span_x) {
            error += span_x;
            py += step_y;
        }
    }
}

/*
 * Draw the outline of a circle of the given radius centred at (cx, cy) in
 * colour (r,g,b), using an integer midpoint-circle scheme: each iteration
 * produces one point (x, y) with x >= y and mirrors it into all eight octants.
 * Pixels outside the w x h image are clipped by set_pixel(). A negative radius
 * draws nothing (the loop condition fails immediately).
 */
static void draw_circle(unsigned char* img, int w, int h, int cx, int cy, int radius, unsigned char r, unsigned char g, unsigned char b) {
    int x = radius;
    int y = 0;
    int err = 0; /* decision variable steering when x must step inwards */

    while (x >= y) {
        /* 8-way symmetry: plot (±x,±y) and (±y,±x) around the centre. */
        set_pixel(img, w, h, cx + x, cy + y, r, g, b);
        set_pixel(img, w, h, cx + y, cy + x, r, g, b);
        set_pixel(img, w, h, cx - y, cy + x, r, g, b);
        set_pixel(img, w, h, cx - x, cy + y, r, g, b);
        set_pixel(img, w, h, cx - x, cy - y, r, g, b);
        set_pixel(img, w, h, cx - y, cy - x, r, g, b);
        set_pixel(img, w, h, cx + y, cy - x, r, g, b);
        set_pixel(img, w, h, cx + x, cy - y, r, g, b);
        /* Step y outwards while the error allows it... */
        if (err <= 0) {
            y += 1;
            err += 2 * y + 1;
        }
        /* ...and pull x inwards once the error turns positive. */
        if (err > 0) {
            x -= 1;
            err -= 2 * x + 1;
        }
    }
}

/*
 * Compose two packed RGB images (3 bytes per pixel, row-major) side by side:
 * img1 on the left, img2 immediately to its right, both top-aligned.
 * The canvas is (w1 + w2) x max(h1, h2). Area not covered by either image is
 * filled with dark gray (value 30 in every channel).
 *
 * *out_w / *out_h receive the canvas size. Returns a malloc'd buffer the caller
 * must free(), or NULL on allocation failure.
 *
 * Sizes are computed in size_t to avoid int overflow on large canvases. The
 * background is set with a single memset, and rows are copied with memcpy.
 * The old code zeroed the buffer with calloc and then overwrote every byte in a loop.
 */
static unsigned char* create_side_by_side(const unsigned char* img1, int w1, int h1,
                                          const unsigned char* img2, int w2, int h2,
                                          int* out_w, int* out_h) {
    int max_h = (h1 > h2) ? h1 : h2;
    int out_w_val = w1 + w2;
    int out_h_val = max_h;
    *out_w = out_w_val;
    *out_h = out_h_val;

    size_t total = (size_t)out_w_val * (size_t)out_h_val * 3u;
    unsigned char* out = (unsigned char*)malloc(total);
    if (!out) return NULL;

    /* Dark gray background */
    memset(out, 30, total);

    size_t dst_stride = (size_t)out_w_val * 3u;
    size_t row1 = (size_t)w1 * 3u;
    size_t row2 = (size_t)w2 * 3u;

    /* img1 occupies columns [0, w1) */
    for (int y = 0; y < h1; y++) {
        memcpy(out + (size_t)y * dst_stride, img1 + (size_t)y * row1, row1);
    }

    /* img2 occupies columns [w1, w1 + w2) */
    for (int y = 0; y < h2; y++) {
        memcpy(out + (size_t)y * dst_stride + row1, img2 + (size_t)y * row2, row2);
    }

    return out;
}

/*
 * Convert a GrayImage (values nominally in [0,1], as produced by
 * load_gray_image) into a newly malloc'd packed RGB buffer. The buffer uses
 * 3 bytes per pixel, is row-major, and has R = G = B.
 *
 * Each value is scaled by 255 and truncated. Out-of-range inputs are clamped to
 * [0,255] and NaN maps to 0. Converting a float outside the range of unsigned
 * char is undefined behaviour in C/C++, so the clamp is required for safety.
 *
 * Returns NULL if gray is NULL or allocation fails; the caller frees the result.
 */
static unsigned char* rgb_from_gray(const GrayImage* gray) {
    if (!gray) return NULL;
    int w = gray->width;
    int h = gray->height;
    unsigned char* rgb = (unsigned char*)malloc((size_t)w * (size_t)h * 3u);
    if (!rgb) return NULL;
    for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
            float scaled = gray_image_get(gray, y, x) * 255.0f;
            unsigned char val;
            if (!(scaled > 0.0f)) {
                val = 0;                      /* negatives, zero and NaN */
            } else if (scaled >= 255.0f) {
                val = 255;
            } else {
                val = (unsigned char)scaled;  /* truncation, as before */
            }
            size_t idx = ((size_t)y * (size_t)w + (size_t)x) * 3u;
            rgb[idx + 0] = val;
            rgb[idx + 1] = val;
            rgb[idx + 2] = val;
        }
    }
    return rgb;
}

/*
 * Command-line driver:
 *  - loads two images;
 *  - extracts SIFT features from both;
 *  - matches them with the ratio test;
 *  - writes a side-by-side JPEG (quality 95) with up to 50 matches drawn as
 *    coloured lines, plus radius-3 circles at both keypoints.
 *
 * Usage: <prog> <image1> <image2> [output.jpg]   (default: sift_match_result.jpg)
 * Returns 0 on success, 1 on any failure.
 *
 * All resources are released through one cleanup section, including on error
 * paths (the previous version returned early and leaked everything allocated so far).
 */
int main(int argc, char** argv) {
    if (argc < 3) {
        printf("Usage: %s <image1> <image2> [output.jpg]\n", argv[0]);
        return 1;
    }

    const char* img1_path = argv[1];
    const char* img2_path = argv[2];
    const char* output_path = (argc >= 4) ? argv[3] : "sift_match_result.jpg";

    /* Everything starts NULL so the cleanup section can release unconditionally. */
    GrayImage* gray1 = NULL;
    GrayImage* gray2 = NULL;
    FeatureArray* f1 = NULL;
    FeatureArray* f2 = NULL;
    MatchPair* matches = NULL;
    unsigned char* rgb1 = NULL;
    unsigned char* rgb2 = NULL;
    unsigned char* out = NULL;
    int rc = 1;

    /* do { ... } while (0) + break gives a single exit path without goto; goto
     * across initialized declarations is ill-formed in C++ and this file is .cpp. */
    do {
        printf("Loading images...\n");
        gray1 = load_gray_image(img1_path);
        gray2 = load_gray_image(img2_path);
        if (!gray1 || !gray2) {
            fprintf(stderr, "Failed to load images\n");
            break;
        }

        printf("Image1: %dx%d\n", gray1->width, gray1->height);
        printf("Image2: %dx%d\n", gray2->width, gray2->height);

        printf("Extracting SIFT features from image1...\n");
        f1 = sift_extract_features(gray1);
        printf("Features1: %d\n", f1 ? f1->size : 0);

        printf("Extracting SIFT features from image2...\n");
        f2 = sift_extract_features(gray2);
        printf("Features2: %d\n", f2 ? f2->size : 0);

        if (!f1 || !f2 || f1->size == 0 || f2->size == 0) {
            fprintf(stderr, "No features found\n");
            break;
        }

        printf("Matching features...\n");
        int match_count = match_features(f1, f2, &matches);
        printf("Matches: %d\n", match_count);

        /* RGB copies of the grayscale inputs for visualization */
        rgb1 = rgb_from_gray(gray1);
        rgb2 = rgb_from_gray(gray2);
        if (!rgb1 || !rgb2) {
            fprintf(stderr, "Memory allocation failed\n");
            break;
        }

        int out_w, out_h;
        out = create_side_by_side(rgb1, gray1->width, gray1->height,
                                  rgb2, gray2->width, gray2->height,
                                  &out_w, &out_h);
        if (!out) {
            fprintf(stderr, "Memory allocation failed\n");
            break;
        }

        /* Cycle through a fixed palette so neighbouring matches are distinguishable. */
        static const unsigned char colors[][3] = {
            {255, 0, 0}, {0, 255, 0}, {0, 0, 255},
            {255, 255, 0}, {255, 0, 255}, {0, 255, 255},
            {255, 128, 0}, {128, 255, 0}, {0, 128, 255},
            {255, 0, 128}, {128, 0, 255}, {0, 255, 128}
        };
        int color_count = (int)(sizeof(colors) / sizeof(colors[0]));

        /* Limit the number of drawn matches to avoid clutter. */
        int draw_count = (match_count > 50) ? 50 : match_count;

        for (int i = 0; i < draw_count; i++) {
            Feature* feat1 = f1->data[matches[i].idx1];
            Feature* feat2 = f2->data[matches[i].idx2];

            int x1 = (int)feat1->x;
            int y1 = (int)feat1->y;
            /* image2 is drawn to the right of image1, so shift its x by image1's width */
            int x2 = (int)feat2->x + gray1->width;
            int y2 = (int)feat2->y;

            const unsigned char* c = colors[i % color_count];
            draw_line(out, out_w, out_h, x1, y1, x2, y2, c[0], c[1], c[2]);
            draw_circle(out, out_w, out_h, x1, y1, 3, c[0], c[1], c[2]);
            draw_circle(out, out_w, out_h, x2, y2, 3, c[0], c[1], c[2]);
        }

        printf("Saving result to %s...\n", output_path);
        int success = stbi_write_jpg(output_path, out_w, out_h, 3, out, 95);
        if (!success) {
            fprintf(stderr, "Failed to write output image\n");
            break;
        }

        printf("Done. Result saved to %s\n", output_path);
        rc = 0;
    } while (0);

    free(matches);
    free(rgb1);
    free(rgb2);
    free(out);
    /* NOTE(review): NULL guards because it is not visible here whether these free functions accept NULL. */
    if (f1) feature_array_free(f1);
    if (f2) feature_array_free(f2);
    if (gray1) gray_image_free(gray1);
    if (gray2) gray_image_free(gray2);

    return rc;
}

6.2 benchmark_sift.py

#!/usr/bin/env python3
"""
SIFT 性能基准测试:连续运行 10 次,计算平均耗时
"""

import subprocess
import time
import os
import sys

# Image pairs to benchmark; relative paths are resolved against the current working directory.
TEST_PAIRS = [
    ("1.jpg", "2.jpg"),
    ("3.jpg", "4.jpg"),
]

# Number of timed runs per image pair.
RUNS = 10
# Path of the compiled test_sift executable (relative to the current working directory).
TEST_SIFT_PATH = "./build/test_sift"

def run_single(pair_idx, img1, img2, run_idx):
    """Run the SIFT test binary once and time it.

    The binary is invoked as ``TEST_SIFT_PATH img1 img2 bench_<pair_idx>_<run_idx>.jpg``.
    Wall-clock time includes process start-up and image I/O.

    Returns:
        ``(elapsed_seconds, match_count, output_file)``, or ``None`` if the
        process exited with a non-zero status (its stderr is printed).
        ``match_count`` is parsed from a ``Matches: N`` line on stdout; it is 0
        when no such line exists or its value is not an integer.
    """
    output_file = f"bench_{pair_idx}_{run_idx}.jpg"
    cmd = [TEST_SIFT_PATH, img1, img2, output_file]

    start = time.perf_counter()
    result = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        # The child's output uses the console/locale code page; never let a single
        # undecodable byte abort the whole benchmark.
        errors="replace",
    )
    elapsed = time.perf_counter() - start

    if result.returncode != 0:
        print(f"[错误] 运行失败: {result.stderr}")
        return None

    # Parse the number of matches reported on stdout.
    matches = 0
    for line in result.stdout.splitlines():
        if "Matches:" in line:
            try:
                matches = int(line.split(":")[1].strip())
            except ValueError:
                # Malformed count: keep the previous value instead of failing the run.
                pass

    return elapsed, matches, output_file

def benchmark_pair(pair_idx, img1, img2):
    """Benchmark one image pair by calling run_single RUNS times.

    Failed runs are skipped. Returns None when no run succeeded. Otherwise it
    returns a dict with these keys:

    - ``pair``
    - ``avg_ms``, ``min_ms``, ``max_ms``
    - ``avg_matches``
    - ``runs``: the number of successful runs
    - ``last_output``: the output image of the last successful run
    """
    rule = '=' * 50
    print(f"\n{rule}")
    print(f"测试组合 {pair_idx + 1}: {img1} & {img2}")
    print(f"{rule}")

    durations = []
    match_counts = []
    last_output = None

    for run_no in range(RUNS):
        print(f"  第 {run_no+1}/{RUNS} 次运行...", end=" ", flush=True)
        outcome = run_single(pair_idx, img1, img2, run_no)
        if outcome is not None:
            seconds, found, last_output = outcome
            durations.append(seconds)
            match_counts.append(found)
            print(f"耗时: {seconds*1000:.2f} ms, 匹配: {found}")

    if len(durations) == 0:
        return None

    successful = len(durations)
    avg_ms = sum(durations) / successful * 1000
    min_ms = min(durations) * 1000
    max_ms = max(durations) * 1000
    avg_matches = sum(match_counts) / len(match_counts)

    print(f"\n  统计结果:")
    print(f"    平均耗时: {avg_ms:.2f} ms")
    print(f"    最小耗时: {min_ms:.2f} ms")
    print(f"    最大耗时: {max_ms:.2f} ms")
    print(f"    平均匹配: {avg_matches:.1f}")

    summary = {
        "pair": f"{img1} & {img2}",
        "avg_ms": avg_ms,
        "min_ms": min_ms,
        "max_ms": max_ms,
        "avg_matches": avg_matches,
        "runs": successful,
        "last_output": last_output,
    }
    return summary

def main():
    """Benchmark every pair in TEST_PAIRS, keep one result image per pair, and write a report.

    The report is printed and saved to ``benchmark_report.txt`` (UTF-8).
    The script exits with status 1 if the test binary is missing.
    """
    print("="*50)
    print("SIFT 性能基准测试")
    print(f"运行次数: {RUNS} 次/组合")
    print(f"测试程序: {TEST_SIFT_PATH}")
    print("="*50)

    if not os.path.exists(TEST_SIFT_PATH):
        print(f"[错误] 找不到测试程序: {TEST_SIFT_PATH}")
        sys.exit(1)

    results = []
    for idx, (img1, img2) in enumerate(TEST_PAIRS):
        result = benchmark_pair(idx, img1, img2)
        if result:
            results.append(result)

    # Remove the per-run images, keeping only each pair's last successful output.
    # Look outputs up by value, not by position: `results` omits pairs whose runs
    # all failed, so results[idx] could belong to another pair or raise IndexError.
    keep = {r["last_output"] for r in results}
    for idx in range(len(TEST_PAIRS)):
        for i in range(RUNS):
            f = f"bench_{idx}_{i}.jpg"
            if f not in keep and os.path.exists(f):
                os.remove(f)

    # Build the report
    print("\n" + "="*50)
    print("测试完成!生成报告...")
    print("="*50)

    report_lines = [
        "【SIFT 性能基准测试报告】",
        "",
        f"测试配置:连续运行 {RUNS} 次/组合",
        "",
    ]

    for r in results:
        report_lines.append(f"组合: {r['pair']}")
        report_lines.append(f"  平均耗时: {r['avg_ms']:.2f} ms")
        report_lines.append(f"  最小耗时: {r['min_ms']:.2f} ms")
        report_lines.append(f"  最大耗时: {r['max_ms']:.2f} ms")
        report_lines.append(f"  平均匹配: {r['avg_matches']:.1f}")
        report_lines.append("")

    report = "\n".join(report_lines)
    print(report)

    # Save the report to a file
    with open("benchmark_report.txt", "w", encoding="utf-8") as f:
        f.write(report)

    print("\n报告已保存到 benchmark_report.txt")
    print("结果图:")
    for r in results:
        print(f"  {r['last_output']}")

if __name__ == "__main__":
    main()

7. 构建脚本

7.1 build.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CudaSharp - CUDA 动态阈值二值化 DLL 构建脚本
"""

import os
import sys
import subprocess
from pathlib import Path

# Best-effort UTF-8 console setup so the Chinese status messages print correctly.
# PYTHONIOENCODING is only read at interpreter start-up, so this assignment
# affects child Python processes, not this process's existing streams.
os.environ['PYTHONIOENCODING'] = 'utf-8'

if sys.platform == 'win32':
    # PYTHONLEGACYWINDOWSFSENCODING turns the legacy 'mbcs' file-system encoding ON
    # when set to ANY non-empty string, so the former value '0' enabled legacy mode
    # for child Python processes. Removing it keeps the UTF-8 default.
    os.environ.pop('PYTHONLEGACYWINDOWSFSENCODING', None)
    import ctypes
    try:
        # Switch the attached console's input/output code pages to UTF-8 (65001).
        ctypes.windll.kernel32.SetConsoleCP(65001)
        ctypes.windll.kernel32.SetConsoleOutputCP(65001)
    except Exception:
        # No console / restricted environment: continue with the defaults.
        pass

# Re-encode this process's own text streams (TextIOWrapper.reconfigure, Python 3.7+).
for _stream in (sys.stdout, sys.stderr):
    try:
        _stream.reconfigure(encoding='utf-8')
    except (AttributeError, ValueError):
        # Stream replaced or None, or reconfigure unavailable: keep the current encoding.
        pass

# Preferred Visual Studio 2022 Build Tools location; find_vs_path() falls back to other editions/locations.
VS_PATH = r"C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools"
# Optional overrides. None means auto-detect; find_cuda(), find_windows_sdk() and
# find_msvc_version() store what they find here.
CUDA_ROOT = None
WINDOWS_SDK_ROOT = None
MSVC_VERSION = None

# (virtual arch, real arch) pairs; build() turns each one into
# "-gencode arch=compute_XX,code=sm_XX".
# NOTE(review): only SASS is embedded (no code=compute_XX PTX entry), so GPUs newer than
# the newest listed architecture may not be able to load the kernels — confirm whether a
# PTX fallback is wanted.
CUDA_ARCHS = [
    ("61", "sm_61"),
    ("75", "sm_75"),
    ("86", "sm_86"),
    ("89", "sm_89"),
]

# Sources passed to the single nvcc invocation; paths are relative to the directory the script is run from.
SOURCE_FILES = [
    "src/CudaSharpNative.cu",
    "src/gray_image.c",
    "src/image_ops.cu",
    "src/sift_types.c",
    "src/sift_algorithm.cu",
    "src/safemem/safemem_host.c",
    "src/safemem/safemem_device.cu",
    "src/safemem/safemem_pool.cu",
    "src/sift_detect.cu",
    "src/sift_matcher.cu",
    "src/image_similarity.cu",
]
# File name of the produced DLL; build() writes it into BUILD_DIR.
OUTPUT_NAME = "CudaSharpNative.dll"
# Output directory for the DLL and build.log, relative to the current working directory.
BUILD_DIR = "build"


def to_absolute_path(path):
    """Return *path* (a str or pathlib.Path) as an absolute path string."""
    text = str(path) if isinstance(path, Path) else path
    return os.path.abspath(text)


def print_header():
    """Print the script banner: a title framed by two 50-character '=' rules, then a blank line."""
    rule = "=" * 50
    for text in (rule, "CudaSharp - CUDA DLL 构建脚本", rule, ""):
        print(text)


def find_vs_path():
    """Locate a Visual Studio installation that provides VC/Auxiliary/Build/vcvarsall.bat.

    VS_PATH is tried first. After that, a fixed list of VS 2022/2019 editions and
    locations is tried. The first match is stored in VS_PATH and returned. The
    script exits with status 1 when none is found.
    """
    global VS_PATH

    def _has_vcvarsall(root):
        return (Path(root) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat").exists()

    if _has_vcvarsall(VS_PATH):
        print(f"[信息] 使用指定的 BuildTools: {VS_PATH}")
        return VS_PATH

    fallbacks = (
        r"C:\Program Files\Microsoft Visual Studio\2022\BuildTools",
        r"C:\Program Files (x86)\Microsoft Visual Studio\2022\Community",
        r"C:\Program Files\Microsoft Visual Studio\2022\Community",
        r"C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools",
    )
    found = next((candidate for candidate in fallbacks if _has_vcvarsall(candidate)), None)
    if found is not None:
        VS_PATH = found
        print(f"[信息] 找到 Visual Studio: {VS_PATH}")
        return VS_PATH

    print("[错误] 找不到 Visual Studio!")
    sys.exit(1)


def get_vcvars_env(vs_path):
    """Run ``vcvarsall.bat x64`` in a shell and capture the resulting environment.

    The variables printed by ``set`` afterwards are parsed into a dict
    (name -> value), split at the first '='. The script exits with status 1 if
    the batch file fails or anything raises.
    """
    script = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat"
    shell_cmd = f'"{script}" x64 && set'

    try:
        proc = subprocess.run(
            shell_cmd,
            shell=True,
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='ignore'
        )

        if proc.returncode != 0:
            print("[错误] vcvarsall.bat 执行失败")
            print(proc.stderr)
            sys.exit(1)

        # Later duplicates overwrite earlier ones, as with sequential assignment.
        return dict(
            entry.split('=', 1)
            for entry in proc.stdout.splitlines()
            if '=' in entry
        )

    except Exception as e:
        print(f"[错误] 无法获取 VS 环境: {e}")
        sys.exit(1)


def find_windows_sdk():
    """Locate the Windows 10/11 SDK root ("Windows Kits\\10").

    WINDOWS_SDK_ROOT wins if it is set and exists. Otherwise the standard x86
    and x64 Program Files locations are tried, and the first existing one is
    cached in WINDOWS_SDK_ROOT. Returns None with a warning when nothing is found.
    """
    global WINDOWS_SDK_ROOT
    if WINDOWS_SDK_ROOT and Path(WINDOWS_SDK_ROOT).exists():
        print(f"[信息] 使用指定的 Windows SDK: {WINDOWS_SDK_ROOT}")
        return WINDOWS_SDK_ROOT

    candidates = (
        r"C:\Program Files (x86)\Windows Kits\10",
        r"C:\Program Files\Windows Kits\10",
    )
    found = next((root for root in candidates if Path(root).exists()), None)
    if found is None:
        print("[警告] 未找到 Windows SDK 安装路径!")
        return None

    WINDOWS_SDK_ROOT = found
    print(f"[信息] 找到 Windows SDK: {WINDOWS_SDK_ROOT}")
    return WINDOWS_SDK_ROOT


def get_sdk_version(sdk_root):
    """Return the newest Windows SDK version directory under ``<sdk_root>/Include``.

    Only directories named as dotted integers with at least two parts (e.g.
    ``10.0.22621.0``) count. They are compared numerically. Returns None when
    sdk_root is falsy, Include is missing, or no such directory exists.
    """
    if not sdk_root:
        return None

    include_dir = Path(sdk_root) / "Include"
    if not include_dir.exists():
        return None

    def _is_version(name):
        pieces = name.split('.')
        return len(pieces) >= 2 and all(piece.isdigit() for piece in pieces)

    candidates = [
        entry.name
        for entry in include_dir.iterdir()
        if entry.is_dir() and _is_version(entry.name)
    ]
    if not candidates:
        return None

    newest = sorted(candidates, key=lambda v: [int(n) for n in v.split('.')])[-1]
    print(f"[信息] 使用 Windows SDK 版本: {newest}")
    return newest


def find_cuda():
    """Locate the CUDA toolkit root and cache it in CUDA_ROOT.

    Search order:
      1. CUDA_ROOT, if it is set and the directory exists.
      2. The CUDA_PATH environment variable. It is now used only if it points to
         an existing directory; a stale value used to be returned as-is and only
         failed later in check_nvcc.
      3. The preferred toolkit versions v12.6 ... v11.8, in that order (unchanged).
      4. New fallback: any other ``v<major>.<minor>`` directory under the default
         toolkit location, newest first. This covers releases missing from the
         list, e.g. v12.5 or v12.8.

    The preferred list stays ahead of the scan because newer toolkits may drop
    architectures listed in CUDA_ARCHS. NOTE(review): CUDA 13 is believed to have
    removed sm_61 — confirm.

    Exits the script with status 1 when nothing is found.
    """
    global CUDA_ROOT
    if CUDA_ROOT and Path(CUDA_ROOT).exists():
        print(f"[信息] 使用指定的 CUDA: {CUDA_ROOT}")
        return CUDA_ROOT

    env_cuda = os.environ.get("CUDA_PATH")
    if env_cuda:
        if Path(env_cuda).exists():
            CUDA_ROOT = env_cuda
            print(f"[信息] 从环境变量获取 CUDA: {CUDA_ROOT}")
            return CUDA_ROOT
        print(f"[警告] 环境变量 CUDA_PATH 指向的目录不存在: {env_cuda}")

    toolkit_base = r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA"
    preferred = [
        toolkit_base + r"\v12.6",
        toolkit_base + r"\v12.4",
        toolkit_base + r"\v12.3",
        toolkit_base + r"\v12.2",
        toolkit_base + r"\v12.1",
        toolkit_base + r"\v12.0",
        toolkit_base + r"\v11.8",
    ]

    for path in preferred:
        if Path(path).exists():
            CUDA_ROOT = path
            print(f"[信息] 找到 CUDA: {CUDA_ROOT}")
            return CUDA_ROOT

    def _version_key(name):
        # "v12.8" -> (12, 8); anything not shaped like v<digits>(.<digits>)* -> None
        if name[:1] not in ("v", "V"):
            return None
        parts = name[1:].split(".")
        if not all(part.isdigit() for part in parts):
            return None
        return tuple(int(part) for part in parts)

    base = Path(toolkit_base)
    candidates = []
    if base.is_dir():
        for entry in base.iterdir():
            key = _version_key(entry.name)
            if key is not None and entry.is_dir():
                candidates.append((key, entry))

    if candidates:
        newest = max(candidates, key=lambda item: item[0])[1]
        CUDA_ROOT = str(newest)
        print(f"[信息] 找到 CUDA: {CUDA_ROOT}")
        return CUDA_ROOT

    print("[错误] 未找到 CUDA 安装路径!")
    sys.exit(1)


def check_nvcc(cuda_root):
    """Verify that ``<cuda_root>/bin/nvcc.exe`` exists, then print its ``--version`` output.

    Returns the nvcc path as a string. The script exits with status 1 if the
    executable is missing or cannot be run.
    """
    nvcc = Path(cuda_root) / "bin" / "nvcc.exe"
    if not nvcc.exists():
        print(f"[错误] 找不到 nvcc: {nvcc}")
        sys.exit(1)

    nvcc_str = str(nvcc)
    try:
        version = subprocess.run(
            [nvcc_str, "--version"],
            capture_output=True,
            text=True
        )
        print("[信息] NVCC 版本:")
        print(version.stdout)
        return nvcc_str
    except Exception as e:
        print(f"[错误] 无法运行 nvcc: {e}")
        sys.exit(1)


def find_msvc_version(vs_path):
    """Return the newest MSVC toolset version under ``<vs_path>/VC/Tools/MSVC``.

    Only toolset directories containing both ``bin`` and ``include`` count. They
    are ordered by their numeric dotted components. The result is cached in
    MSVC_VERSION, and a cached value is returned immediately. Returns None when
    nothing usable is found.
    """
    global MSVC_VERSION
    if MSVC_VERSION:
        return MSVC_VERSION

    tools_dir = Path(vs_path) / "VC" / "Tools" / "MSVC"
    if not tools_dir.exists():
        return None

    usable = [
        entry.name
        for entry in tools_dir.iterdir()
        if entry.is_dir() and (entry / "bin").exists() and (entry / "include").exists()
    ]
    if not usable:
        return None

    usable.sort(key=lambda v: [int(piece) for piece in v.split('.') if piece.isdigit()], reverse=True)
    MSVC_VERSION = usable[0]
    print(f"[信息] 找到 MSVC 版本: {MSVC_VERSION}")
    return MSVC_VERSION


def clean_build_dir(build_path, output_name):
    """Delete the previous output DLL and intermediate build artifacts in *build_path*.

    Args:
        build_path: pathlib.Path of the build directory.
        output_name: file name of the DLL produced by the build.

    Returns:
        False only when the old output file exists but deleting it raises
        PermissionError (presumably because another process has it loaded).
        True otherwise.

    Intermediate files (*.obj, *.exp, *.lib, *.pdb, *.ilk) are removed on a
    best-effort basis. A failure is now reported as a warning instead of being
    silently swallowed by a bare ``except``.
    """
    output_file = build_path / output_name
    if output_file.exists():
        try:
            output_file.unlink()
            print(f"[信息] 已删除旧的输出文件: {output_file}")
        except PermissionError:
            print(f"[错误] 无法删除 {output_file},文件可能被占用")
            return False
        except Exception as e:
            print(f"[警告] 删除旧文件时出错: {e}")

    for pattern in ["*.obj", "*.exp", "*.lib", "*.pdb", "*.ilk"]:
        for f in build_path.glob(pattern):
            try:
                f.unlink()
            except OSError as e:
                print(f"[警告] 无法删除中间文件 {f}: {e}")

    return True


def setup_library_paths(env, vs_path, msvc_ver, sdk_root, sdk_version, cuda_root):
    """Prepend x64 linker search directories to ``env['LIB']`` and return *env*.

    Candidates are considered in this order, and only those that exist are added:

    - MSVC ``lib/x64`` and ``ATLMFC/lib/x64`` (require msvc_ver);
    - Windows SDK ``um/x64`` and ``ucrt/x64`` (require sdk_root and sdk_version);
    - CUDA ``lib/x64`` (requires cuda_root).

    Any existing LIB value is kept after the new entries. *env* is modified in
    place. When nothing is found, LIB is left untouched and nothing else is printed.
    """
    candidates = []  # (message prefix, directory) in priority order

    if msvc_ver:
        toolset_dir = Path(vs_path) / "VC" / "Tools" / "MSVC" / msvc_ver
        candidates.append(("[信息] 添加 MSVC lib (x64): ", toolset_dir / "lib" / "x64"))
        candidates.append(("[信息] 添加 ATLMFC lib: ", toolset_dir / "ATLMFC" / "lib" / "x64"))

    if sdk_root and sdk_version:
        sdk_lib_dir = Path(sdk_root) / "Lib" / sdk_version
        candidates.append(("[信息] 添加 SDK lib (um/x64): ", sdk_lib_dir / "um" / "x64"))
        candidates.append(("[信息] 添加 SDK lib (ucrt/x64): ", sdk_lib_dir / "ucrt" / "x64"))

    if cuda_root:
        candidates.append(("[信息] 添加 CUDA lib (x64): ", Path(cuda_root) / "lib" / "x64"))

    found = []
    for prefix, directory in candidates:
        if directory.exists():
            found.append(str(directory))
            print(f"{prefix}{directory}")

    if not found:
        return env

    joined = os.pathsep.join(found)
    previous = env.get('LIB', '')
    env['LIB'] = f"{joined}{os.pathsep}{previous}" if previous else joined

    print()
    print("[调试] 最终 LIB 路径:")
    for entry in (raw.strip() for raw in env.get('LIB', '').split(os.pathsep)):
        if entry:
            mark = "v" if Path(entry).exists() else "x"
            print(f"  [{mark}] {entry}")
    print()

    return env


def build(vs_env, nvcc_path):
    """Compile SOURCE_FILES with nvcc into ``BUILD_DIR/OUTPUT_NAME``.

    Steps:
      1. Clean old artifacts and check that every source file exists.
      2. Assemble the nvcc command, with one ``-gencode`` per CUDA_ARCHS entry.
      3. Merge the vcvarsall environment into a copy of os.environ.
      4. Make cl.exe plus the MSVC / Windows SDK include and lib directories reachable.
      5. Run nvcc with ``cwd=BUILD_DIR``, writing its combined output to the
         console and to ``BUILD_DIR/build.log``.

    Fixes compared to the previous version:
      * Source files and ``-I`` directories are passed as absolute paths. nvcc
        runs inside BUILD_DIR, so the former project-relative "src/..." and
        "-I." resolved to build/src/... and to the build directory itself.
      * On Windows, vcvarsall's variables are merged case-insensitively (its
        ``set`` output reports "Path", while os.environ uses "PATH"). A plain
        update() kept both keys.
      * The newest MSVC toolset is chosen by numeric version, consistent with
        find_msvc_version, instead of plain string order.

    Exits the script with status 1 on any failure.
    """
    build_path = Path(BUILD_DIR)
    build_path.mkdir(exist_ok=True)

    if not clean_build_dir(build_path, OUTPUT_NAME):
        sys.exit(1)

    for src in SOURCE_FILES:
        s = Path(src)
        if not s.exists():
            print(f"[错误] 找不到源文件: {src}")
            sys.exit(1)

    gencode_args = []
    for arch, code in CUDA_ARCHS:
        gencode_args.extend(["-gencode", f"arch=compute_{arch},code={code}"])

    # nvcc runs with cwd=BUILD_DIR, so project-relative inputs must be absolute.
    source_for_build = [to_absolute_path(src) for src in SOURCE_FILES]
    # Relative to cwd=BUILD_DIR, so the DLL lands in BUILD_DIR.
    output_for_build = OUTPUT_NAME

    cmd = [
        nvcc_path,
        *source_for_build,
        "--use-local-env",
        "--shared",
        "-o", output_for_build,
        "-O3",
        *gencode_args,
        "--use_fast_math",
        "-Xcompiler", "/O2 /openmp /W4",
        "-I" + to_absolute_path("."),
        "-I" + to_absolute_path("stb_image"),
        "-DNDEBUG"
    ]

    actual_output = build_path / OUTPUT_NAME
    print("[信息] 开始编译...")
    print(f"[信息] 输出文件: {actual_output}")
    print(f"[信息] 命令: {' '.join(cmd)}")
    print()

    env = os.environ.copy()
    if os.name == 'nt':
        # Windows environment names are case-insensitive and os.environ stores them
        # upper-cased; normalize vcvarsall's names so e.g. "Path" replaces "PATH"
        # instead of coexisting with it.
        env.update({key.upper(): value for key, value in vs_env.items()})
    else:
        env.update(vs_env)

    vs_path = VS_PATH

    def _newest_toolset(msvc_dir):
        # Numeric order ("14.100" > "14.99"); plain string/Path ordering gets this wrong.
        dirs = [d for d in msvc_dir.iterdir() if d.is_dir()]
        if not dirs:
            return None
        return max(dirs, key=lambda d: [int(p) for p in d.name.split('.') if p.isdigit()])

    vcinstalldir = vs_env.get('VCINSTALLDIR', '')
    if vcinstalldir:
        msvc_root = Path(vcinstalldir) / 'Tools' / 'MSVC'
        if msvc_root.exists():
            newest = _newest_toolset(msvc_root)
            if newest is not None:
                cl_bin = newest / 'bin' / 'HostX64' / 'x64'
                if cl_bin.exists():
                    cl_bin_str = str(cl_bin)
                    if cl_bin_str not in env.get('PATH', ''):
                        env['PATH'] = cl_bin_str + os.pathsep + env.get('PATH', '')
                        print(f"[信息] 添加 MSVC 路径: {cl_bin_str}")

    # Fallback: still no cl.exe on PATH -> use the newest toolset under vs_path.
    if not any((Path(p) / 'cl.exe').exists() for p in env.get('PATH', '').split(os.pathsep) if p):
        vs_msvc = Path(vs_path) / 'VC' / 'Tools' / 'MSVC'
        if vs_msvc.exists():
            newest = _newest_toolset(vs_msvc)
            if newest is not None:
                cl_bin = newest / 'bin' / 'HostX64' / 'x64'
                if cl_bin.exists():
                    env['PATH'] = str(cl_bin) + os.pathsep + env.get('PATH', '')
                    print(f"[信息] 添加 MSVC 路径: {cl_bin}")

    cuda_bin = str(Path(nvcc_path).parent)
    if cuda_bin not in env.get('PATH', ''):
        env['PATH'] = env.get('PATH', '') + os.pathsep + cuda_bin

    msvc_ver = find_msvc_version(vs_path)
    if msvc_ver:
        msvc_include = Path(vs_path) / "VC" / "Tools" / "MSVC" / msvc_ver / "include"
        if msvc_include.exists():
            msvc_include_str = str(msvc_include)
            current_include = env.get('INCLUDE', '')
            if msvc_include_str not in current_include:
                env['INCLUDE'] = msvc_include_str + os.pathsep + current_include
                print(f"[信息] 添加 MSVC include: {msvc_include_str}")

    sdk_root = find_windows_sdk()
    sdk_version = get_sdk_version(sdk_root)

    if sdk_root and sdk_version:
        sdk_include_base = Path(sdk_root) / "Include" / sdk_version
        sdk_subdirs = ['ucrt', 'shared', 'um', 'winrt']
        for subdir in sdk_subdirs:
            subdir_path = sdk_include_base / subdir
            if subdir_path.exists():
                path_str = str(subdir_path)
                current_include = env.get('INCLUDE', '')
                if path_str not in current_include:
                    env['INCLUDE'] = path_str + os.pathsep + current_include
                    print(f"[信息] 添加 SDK include ({subdir}): {path_str}")

    print()
    print("[调试] 最终 INCLUDE 路径:")
    for p in env.get('INCLUDE', '').split(os.pathsep):
        if p.strip():
            exists = "v" if Path(p.strip()).exists() else "x"
            print(f"  [{exists}] {p.strip()}")
    print()

    # Sanity check: crtdefs.h missing from INCLUDE usually means a broken MSVC setup.
    crtdefs_found = False
    for p in env.get('INCLUDE', '').split(os.pathsep):
        if p.strip():
            crtdefs_path = Path(p.strip()) / "crtdefs.h"
            if crtdefs_path.exists():
                print(f"[信息] 找到 crtdefs.h: {crtdefs_path}")
                crtdefs_found = True
                break

    if not crtdefs_found:
        print("[警告] 未找到 crtdefs.h!编译可能会失败")

    cuda_root = find_cuda()
    env = setup_library_paths(env, vs_path, msvc_ver, sdk_root, sdk_version, cuda_root)

    log_path = build_path / "build.log"
    try:
        # Write with a UTF-8 BOM so Windows editors recognise the log as UTF-8.
        with open(log_path, "w", encoding="utf-8-sig") as log_file:
            result = subprocess.run(
                cmd,
                env=env,
                cwd=BUILD_DIR,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                encoding="utf-8",
                errors="ignore"
            )
            log_file.write(result.stdout)
            print(result.stdout, end="")

        if result.returncode != 0:
            print()
            print(f"[错误] 编译失败!详细日志见: {log_path}")
            sys.exit(1)
        else:
            print(f"[信息] 编译日志已保存到: {log_path}")

    except Exception as e:
        print(f"[错误] 编译过程出错: {e}")
        sys.exit(1)


def find_msbuild(vs_path):
    """Return the path of MSBuild.exe inside a Visual Studio installation, or None.

    MSBuild/Current is preferred over MSBuild/15.0. Within each version, the
    64-bit (``Bin/amd64``) executable is preferred over the 32-bit one (``Bin``).
    """
    msbuild_root = Path(vs_path) / "MSBuild"
    for version in ("Current", "15.0"):
        for arch_parts in (("amd64",), ()):
            candidate = msbuild_root.joinpath(version, "Bin", *arch_parts, "MSBuild.exe")
            if candidate.exists():
                return str(candidate)
    return None


def setup_vs_environment(vs_path):
    """Configure Visual Studio build environment variables in this process.

    A lightweight in-process substitute for running VsDevCmd.bat. It:
      * sets a fixed group of VS variables (VisualStudioVersion 17.0, amd64
        host/target arch, install-dir variables);
      * picks the newest MSVC toolset under ``VC/Tools/MSVC`` (compared
        numerically per version component) and exports VCToolsVersion /
        VCToolsInstallDir;
      * prepends whichever of MSBuild (amd64 preferred), the MSVC
        Hostx64/x64 bin dir, Common7/Tools and Common7/IDE exist to PATH;
      * prepends the toolset's ``include`` dir to INCLUDE and its
        ``lib/x64`` dir to LIB and LIBPATH. Any values already present are
        kept after the new entries instead of being discarded.

    Only ``os.environ`` is modified; processes started afterwards inherit it.

    Args:
        vs_path: Visual Studio installation root (str or path-like).

    Returns:
        Always True.
    """
    import os
    from pathlib import Path

    vs_path = Path(vs_path)

    def _version_key(entry):
        # Numeric, component-wise ordering so "14.40.33807" ranks above
        # "14.9.x" (plain string/Path sorting gets this wrong). Non-numeric
        # components sort lowest; the raw name breaks any remaining tie.
        parts = tuple(int(p) if p.isdigit() else -1 for p in entry.name.split("."))
        return (parts, entry.name)

    def _prepend_env(name, entries):
        # Put `entries` in front of the variable's current value (if any) so
        # previously configured entries survive instead of being clobbered.
        existing = os.environ.get(name, "")
        os.environ[name] = os.pathsep.join(entries + ([existing] if existing else []))

    # Core variables, modelled on the VsDevCmd.bat logic.
    env_vars = {
        "VSINSTALLDIR": str(vs_path) + "/",
        "VCINSTALLDIR": str(vs_path / "VC") + "/",
        "VS170COMNTOOLS": str(vs_path / "Common7" / "Tools") + "/",
        "VisualStudioVersion": "17.0",
        "VSCMD_VER": "17.0",
        "VSCMD_ARG_ARCH": "amd64",
        "VSCMD_ARG_HOST_ARCH": "amd64",
    }

    # Locate the MSVC toolset directory (e.g. 14.40.33807) and take the newest.
    vc_tools_root = vs_path / "VC" / "Tools" / "MSVC"
    vc_tools_version = None
    if vc_tools_root.exists():
        versions = [d for d in vc_tools_root.iterdir() if d.is_dir()]
        if versions:
            vc_tools_version = max(versions, key=_version_key).name
            env_vars["VCToolsVersion"] = vc_tools_version
            env_vars["VCToolsInstallDir"] = str(vc_tools_root / vc_tools_version) + "/"

    # Directories to put on PATH, in priority order.
    path_additions = []

    # 1. MSBuild: only the first existing candidate (amd64 preferred).
    msbuild_paths = [
        vs_path / "MSBuild" / "Current" / "Bin" / "amd64",
        vs_path / "MSBuild" / "Current" / "Bin",
    ]
    for p in msbuild_paths:
        if p.exists():
            path_additions.append(str(p))
            break

    # 2. MSVC compiler (64-bit host, 64-bit target).
    if vc_tools_version:
        vc_bin = vc_tools_root / vc_tools_version / "bin" / "Hostx64" / "x64"
        if vc_bin.exists():
            path_additions.append(str(vc_bin))

    # 3./4. Common7 Tools and IDE directories.
    for extra in (vs_path / "Common7" / "Tools", vs_path / "Common7" / "IDE"):
        if extra.exists():
            path_additions.append(str(extra))

    # Apply the plain variables.
    for key, value in env_vars.items():
        os.environ[key] = value
        print(f"[环境] {key}={value}")

    if path_additions:
        _prepend_env("PATH", path_additions)
        print(f"[环境] PATH 添加: {path_additions}")

    # Base INCLUDE / LIB entries from the selected toolset.
    include_paths = []
    lib_paths = []

    if vc_tools_version:
        vc_tools_dir = vc_tools_root / vc_tools_version
        include_base = vc_tools_dir / "include"
        if include_base.exists():
            include_paths.append(str(include_base))
        lib_base = vc_tools_dir / "lib" / "x64"
        if lib_base.exists():
            lib_paths.append(str(lib_base))

    if include_paths:
        _prepend_env("INCLUDE", include_paths)
        print(f"[环境] INCLUDE={include_paths}")

    if lib_paths:
        _prepend_env("LIB", lib_paths)
        _prepend_env("LIBPATH", lib_paths)
        print(f"[环境] LIB={lib_paths}")

    return True


def build_csharp(vs_path):
    """Build the C# wrapper project with MSBuild, then stage the CUDA DLL.

    Steps:
      1. Skip with a warning if ``src/CudaBinarizeLib.csproj`` is missing.
      2. Configure Visual Studio environment variables in-process
         (``setup_vs_environment``).
      3. Locate MSBuild (``find_msbuild``); skip with a warning if absent.
      4. Run MSBuild (Release / AnyCPU, ``/restore``); exit with status 1 on
         failure or if the process cannot be started.
      5. Copy ``BUILD_DIR/OUTPUT_NAME`` into every existing candidate C#
         output directory. A failed copy is reported but does not abort.

    Args:
        vs_path: Visual Studio installation root.
    """
    import shutil
    import subprocess
    import sys
    from pathlib import Path

    print()
    print("=" * 50)
    print("[信息] 开始编译 C# 工程...")
    print("=" * 50)
    print()

    csproj = Path("src/CudaBinarizeLib.csproj")
    if not csproj.exists():
        print(f"[警告] 找不到 C# 工程文件: {csproj},跳过 C# 编译")
        return

    # Configure the VS environment in-process (no VsDevCmd.bat subprocess).
    print("[信息] 设置 Visual Studio 环境变量...")
    setup_vs_environment(vs_path)

    msbuild = find_msbuild(vs_path)
    if not msbuild:
        print("[警告] 找不到 MSBuild.exe,跳过 C# 编译")
        return

    # MSBuild inherits the environment configured above.
    cmd = [
        msbuild,
        str(csproj.resolve()),
        "/p:Configuration=Release",
        "/p:Platform=AnyCPU",
        "/restore"
    ]

    print(f"[信息] 命令: {cmd}")
    print()

    try:
        result = subprocess.run(cmd, shell=False)
        if result.returncode != 0:
            print()
            print("[错误] C# 工程编译失败!")
            sys.exit(1)
    except Exception as e:
        print(f"[错误] C# 编译过程出错: {e}")
        sys.exit(1)

    print()
    print("[成功] C# 工程编译完成")
    print()

    # Stage the native CUDA DLL next to the managed assembly.
    cuda_dll = Path(BUILD_DIR) / OUTPUT_NAME
    if not cuda_dll.exists():
        print(f"[警告] 找不到 CUDA DLL: {cuda_dll}")
        return

    # NOTE(review): these candidates are relative to the current working
    # directory, while the project lives in src/. Confirm they match the
    # csproj's OutputPath (MSBuild may default to src/bin/...).
    csharp_output_dirs = [
        Path("bin") / "x64" / "Release",
        Path("bin") / "Release",
    ]
    existing_dirs = [d for d in csharp_output_dirs if d.exists()]
    if not existing_dirs:
        print(f"[警告] 未找到 C# 输出目录,请手动将 {cuda_dll} 复制到与 C# 程序同一目录")
        return

    copied = False
    for out_dir in existing_dirs:
        target = out_dir / OUTPUT_NAME
        try:
            shutil.copy2(str(cuda_dll), str(target))
            print(f"[信息] 已复制 {cuda_dll} -> {target}")
            copied = True
        except Exception as e:
            print(f"[警告] 复制 DLL 到 {target} 失败: {e}")

    if not copied:
        # Output directories exist but every copy failed. Say so instead of
        # claiming that no output directory was found.
        print(f"[警告] DLL 复制失败,请手动将 {cuda_dll} 复制到与 C# 程序同一目录")

def print_success():
    """Print the Windows build-success banner, including the DLL output path."""
    rule = "=" * 50
    banner_lines = (
        "",
        rule,
        "[成功] 编译完成!",
        f"输出文件: {BUILD_DIR}\\{OUTPUT_NAME}",
        rule,
        "",
        "请确保将 DLL 与 C# 程序放在同一目录。",
        "",
    )
    for text in banner_lines:
        print(text)


def find_nvcc_linux():
    """Locate the nvcc compiler on Linux; exit the process if it cannot be found.

    Search order:
      1. ``nvcc`` on PATH, via ``shutil.which`` (no external ``which`` binary
         is needed).
      2. ``$CUDA_PATH/bin/nvcc``, then ``$CUDA_HOME/bin/nvcc``.
      3. ``/usr/local/cuda/bin/nvcc``.
      4. ``/usr/local/cuda-<version>/bin/nvcc`` for every installed version,
         newest version first.

    Returns:
        Path to an nvcc executable, as a string.

    Exits:
        Prints an error and calls ``sys.exit(1)`` if nothing is found.
    """
    import glob
    import shutil

    on_path = shutil.which("nvcc")
    if on_path:
        return on_path

    for var in ("CUDA_PATH", "CUDA_HOME"):
        root = os.environ.get(var, "")
        if root:
            candidate = os.path.join(root, "bin", "nvcc")
            if os.path.exists(candidate):
                return candidate

    default_path = "/usr/local/cuda/bin/nvcc"
    if os.path.exists(default_path):
        return default_path

    def _version_key(nvcc_path):
        # "/usr/local/cuda-12.6/bin/nvcc" -> (12, 6); non-numeric parts sort lowest.
        install_dir = os.path.basename(os.path.dirname(os.path.dirname(nvcc_path)))
        tag = install_dir[len("cuda-"):]
        return tuple(int(p) if p.isdigit() else -1 for p in tag.split("."))

    for candidate in sorted(glob.glob("/usr/local/cuda-*/bin/nvcc"), key=_version_key, reverse=True):
        if os.path.exists(candidate):
            return candidate

    print("[错误] 找不到 nvcc,请确保 CUDA Toolkit 已安装并加入 PATH")
    sys.exit(1)


def build_linux():
    """Build the CUDA shared library (and, if present, the SIFT test binary) on Linux.

    Steps:
      1. Locate nvcc (``find_nvcc_linux``) and print its version banner.
      2. Remove any stale ``libCudaSharpNative.so`` from BUILD_DIR.
      3. Compile every entry of SOURCE_FILES into the shared library. Combined
         stdout/stderr goes to the console and to ``BUILD_DIR/build.log``.
         A failed compile exits with status 1.
      4. If ``tests/test_sift.cpp`` exists, compile it together with
         SOURCE_FILES minus ``CudaSharpNative.cu`` into ``BUILD_DIR/test_sift``.
         Failures here are reported as warnings only.

    Target architecture: ``-arch=native`` by default. Set the ``CUDA_ARCH``
    environment variable (e.g. ``sm_86``) to override it, for example on a
    build host without a GPU or where ``native`` is unavailable.
    """
    build_path = Path(BUILD_DIR)
    build_path.mkdir(parents=True, exist_ok=True)

    nvcc = find_nvcc_linux()
    print(f"[信息] 使用 NVCC: {nvcc}")
    try:
        result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
        print("[信息] NVCC 版本:")
        print(result.stdout)
        if result.returncode != 0 and result.stderr:
            # Surface why `nvcc --version` failed instead of printing nothing.
            print(result.stderr)
    except Exception as e:
        print(f"[错误] 无法运行 nvcc: {e}")
        sys.exit(1)

    output_name = "libCudaSharpNative.so"

    # Remove the previous build output, if any.
    old_so = build_path / output_name
    if old_so.exists():
        old_so.unlink()

    # Default: build for the GPU in this machine; CUDA_ARCH overrides it.
    arch = os.environ.get("CUDA_ARCH", "").strip() or "native"
    gencode_args = [f"-arch={arch}"]
    print(f"[信息] 目标架构: {gencode_args[0]}")

    # Flags shared by the library and the test program builds.
    common_flags = [
        "-O3",
        *gencode_args,
        "--use_fast_math",
        "-Xcompiler", "-fPIC",
        "-I.",
        "-Ilibs/stb_image",
        "-Isrc",
        "-DNDEBUG",
    ]

    cmd = [
        nvcc,
        *SOURCE_FILES,
        "--shared",
        "-o", str(build_path / output_name),
        *common_flags,
    ]

    print("[信息] 开始编译 CUDA 共享库...")
    print(f"[信息] 命令: {' '.join(cmd)}")
    print()

    log_path = build_path / "build.log"
    try:
        with open(log_path, "w", encoding="utf-8") as log_file:
            result = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                encoding="utf-8",
                errors="ignore"
            )
            log_file.write(result.stdout)
            print(result.stdout, end="")

        if result.returncode != 0:
            print()
            print(f"[错误] 编译失败!详细日志见: {log_path}")
            sys.exit(1)
        else:
            print(f"[信息] 编译日志已保存到: {log_path}")
    except Exception as e:
        print(f"[错误] 编译过程出错: {e}")
        sys.exit(1)

    print()
    print("=" * 50)
    print("[成功] CUDA 共享库编译完成!")
    print(f"输出文件: {BUILD_DIR}/{output_name}")
    print("=" * 50)
    print()

    # Optional SIFT test executable.
    test_cpp = Path("tests/test_sift.cpp")
    if test_cpp.exists():
        print("[信息] 编译 SIFT 测试程序...")
        test_out = build_path / "test_sift"
        # The test links the sources directly and skips the export layer
        # (CudaSharpNative.cu) to avoid duplicate stb_image definitions.
        test_sources = [src for src in SOURCE_FILES if "CudaSharpNative.cu" not in src]
        test_cmd = [
            nvcc,
            str(test_cpp),
            *test_sources,
            "-o", str(test_out),
            *common_flags,
        ]
        print(f"[信息] 命令: {' '.join(test_cmd)}")
        try:
            result = subprocess.run(
                test_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                encoding="utf-8",
                errors="ignore"
            )
            print(result.stdout, end="")
            if result.returncode == 0:
                print(f"[成功] 测试程序编译完成: {test_out}")
            else:
                print("[警告] 测试程序编译失败")
        except Exception as e:
            print(f"[警告] 测试程序编译出错: {e}")
        print()


def main():
    """Script entry point: build for the current platform.

    On Windows (``sys.platform == "win32"``) it:
      * locates Visual Studio and captures its compiler environment;
      * locates CUDA/nvcc and builds the CUDA DLL;
      * prints the success banner and builds the C# project;
      * waits for Enter, but only when stdin is an interactive terminal, so
        unattended runs (CI, pipes, pythonw) neither block nor crash with
        EOFError.

    Any other platform is treated as Linux and uses ``build_linux``.
    """
    print_header()

    if sys.platform == "win32":
        vs_path = find_vs_path()
        print("[信息] 正在设置 Visual Studio 环境...")
        vs_env = get_vcvars_env(vs_path)
        print("[信息] 编译器环境已设置")
        print()
        cuda_root = find_cuda()
        nvcc_path = check_nvcc(cuda_root)
        build(vs_env, nvcc_path)
        print_success()
        build_csharp(vs_path)
        # Keep a double-clicked console window open, but only when a human
        # can actually answer. sys.stdin may be None (pythonw) or non-tty.
        if sys.stdin is not None and sys.stdin.isatty():
            try:
                input("按 Enter 键退出...")
            except EOFError:
                pass
    else:
        print("[信息] 检测到 Linux 平台,使用 Linux 编译路径")
        print()
        build_linux()


# Run the build only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
posted @ 2026-04-18 21:17  qsBye  阅读(4)  评论(0)    收藏  举报