池化层实现及其与激活函数的集成

池化层实现

最大池化模块

max_pooling_2x2.v

`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////
//  2x2 max-pooling layer, stride 1 or 2, streaming (one pixel per clock).
//
//  Notes:
//    - Pixels must arrive in raster order inside each channel (x fastest, then
//      y, then channel) with no gaps. The 2x2 window is assembled from a
//      one-row line buffer plus two neighbour registers, so no frame buffer
//      is needed.
//    - STRIDE must be 1 or 2.
//        STRIDE == 2: one result per non-overlapping 2x2 block, emitted on
//                     the block's bottom-right pixel (x odd, y odd).
//        STRIDE == 1: one result for every pixel with x >= 1 and y >= 1
//                     (overlapping windows).
//    - USE_PADDING is kept for interface compatibility only. Edge padding is
//      not implemented, so with STRIDE == 2 a trailing odd row/column is
//      dropped.
//    - Fully pipelined: a new pixel can be accepted on every cycle.
//////////////////////////////////////////////////////////////////////////////////

module max_pooling_2x2 #(
    parameter DATA_WIDTH = 16,        // sample width (signed)
    parameter IMG_WIDTH = 32,         // input image width
    parameter IMG_HEIGHT = 32,        // input image height
    parameter CHANNELS = 16,          // number of channels
    parameter STRIDE = 2,             // pooling stride (1 or 2)
    parameter USE_PADDING = 0         // reserved, see header
)(
    // Clock and asynchronous active-low reset
    input wire clk,
    input wire rst_n,

    // Control
    input wire enable,                // global clock enable
    input wire start,                 // arms the block for a new image (may coincide with the first pixel)

    // Pixel stream input
    input wire signed [DATA_WIDTH-1:0] pixel_in,
    input wire pixel_valid_in,
    input wire [15:0] pixel_x,        // column of pixel_in
    input wire [15:0] pixel_y,        // row of pixel_in
    input wire [15:0] pixel_channel,  // channel of pixel_in

    // Pooled output (pool_valid_out is a one-cycle strobe)
    output reg signed [DATA_WIDTH-1:0] pool_out,
    output reg pool_valid_out,
    output reg [15:0] pool_x,         // output column
    output reg [15:0] pool_y,         // output row
    output reg [15:0] pool_channel,   // output channel

    // Status
    output reg pooling_done,          // one-cycle strobe after the last pixel of the last channel
    output wire [31:0] debug_status
);

    // ========================================
    // State encoding (values kept compatible with the debug_status field)
    // ========================================
    reg [2:0] pool_state;
    localparam IDLE   = 3'b000;   // waiting for start
    localparam STREAM = 3'b001;   // accepting pixels
    localparam DONE   = 3'b100;   // entered right after the last pixel

    // ========================================
    // Storage
    // ========================================

    // line_buffer[c] holds the most recent pixel seen in column c, i.e. the row
    // above while the current row is being streamed. It is intentionally not
    // reset so it can map to distributed RAM. Every entry used by a window is
    // written (previous row) before it is read.
    (* ram_style = "distributed" *)
    reg signed [DATA_WIDTH-1:0] line_buffer [0:IMG_WIDTH-1];

    reg signed [DATA_WIDTH-1:0] prev_pixel;  // pixel at (x-1, y)
    reg signed [DATA_WIDTH-1:0] prev_top;    // pixel at (x-1, y-1)

    reg [15:0] curr_x;                       // coordinates of the last accepted pixel
    reg [15:0] curr_y;
    reg [15:0] curr_channel;
    reg [15:0] out_x;                        // coordinates of the last emitted result
    reg [15:0] out_y;

    // ========================================
    // Window assembly and 4-input max (combinational)
    // ========================================
    wire in_range = (pixel_x < IMG_WIDTH);

    // Old contents of the current column = pixel at (x, y-1). It is read here
    // before the non-blocking write below replaces it with the current pixel.
    wire signed [DATA_WIDTH-1:0] top_right = in_range ? line_buffer[pixel_x] : {DATA_WIDTH{1'b0}};

    wire signed [DATA_WIDTH-1:0] max_row0  = (prev_top   > top_right) ? prev_top   : top_right;
    wire signed [DATA_WIDTH-1:0] max_row1  = (prev_pixel > pixel_in ) ? prev_pixel : pixel_in;
    wire signed [DATA_WIDTH-1:0] max_value = (max_row0   > max_row1 ) ? max_row0   : max_row1;

    // The current pixel closes a window (it is the window's bottom-right corner).
    wire window_complete = (STRIDE == 1) ? ((pixel_x != 16'd0) && (pixel_y != 16'd0))
                                         : (pixel_x[0] & pixel_y[0]);

    // Output coordinates of the window closed by the current pixel.
    wire [15:0] win_x = (STRIDE == 1) ? (pixel_x - 16'd1) : (pixel_x >> 1);
    wire [15:0] win_y = (STRIDE == 1) ? (pixel_y - 16'd1) : (pixel_y >> 1);

    wire last_pixel = (pixel_x == IMG_WIDTH - 1) && (pixel_y == IMG_HEIGHT - 1) &&
                      (pixel_channel == CHANNELS - 1);

    wire accept = enable && pixel_valid_in && in_range;        // pixel updates the history
    wire armed  = (pool_state == STREAM) || start;              // results may be emitted
    wire fire   = accept && armed;

    // ========================================
    // Line buffer / neighbour registers
    // ========================================
    always @(posedge clk) begin
        if (accept)
            line_buffer[pixel_x] <= pixel_in;
    end

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            prev_pixel <= {DATA_WIDTH{1'b0}};
            prev_top   <= {DATA_WIDTH{1'b0}};
        end else if (accept) begin
            prev_pixel <= pixel_in;
            prev_top   <= top_right;   // becomes (x-1, y-1) for the next pixel
        end
    end

    // ========================================
    // Control and output registers
    // ========================================
    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            pool_state     <= IDLE;
            curr_x         <= 16'd0;
            curr_y         <= 16'd0;
            curr_channel   <= 16'd0;
            out_x          <= 16'd0;
            out_y          <= 16'd0;
            pool_out       <= {DATA_WIDTH{1'b0}};
            pool_valid_out <= 1'b0;
            pool_x         <= 16'd0;
            pool_y         <= 16'd0;
            pool_channel   <= 16'd0;
            pooling_done   <= 1'b0;
        end else begin
            // Strobes default low so each one lasts exactly one cycle.
            pool_valid_out <= 1'b0;
            pooling_done   <= 1'b0;

            if (enable) begin
                case (pool_state)
                    IDLE, DONE: begin
                        // DONE also honours start so back-to-back images are not lost.
                        if (start) begin
                            curr_x       <= 16'd0;
                            curr_y       <= 16'd0;
                            curr_channel <= 16'd0;
                            out_x        <= 16'd0;
                            out_y        <= 16'd0;
                            pool_state   <= STREAM;
                        end else begin
                            pool_state   <= IDLE;
                        end
                    end
                    STREAM: begin
                        // pixel handling below
                    end
                    default: pool_state <= IDLE;
                endcase
            end

            if (fire) begin
                curr_x       <= pixel_x;
                curr_y       <= pixel_y;
                curr_channel <= pixel_channel;

                if (window_complete) begin
                    pool_out       <= max_value;
                    pool_x         <= win_x;
                    pool_y         <= win_y;
                    pool_channel   <= pixel_channel;
                    out_x          <= win_x;
                    out_y          <= win_y;
                    pool_valid_out <= 1'b1;
                end

                if (last_pixel) begin
                    pooling_done <= 1'b1;
                    pool_state   <= DONE;   // overrides the STREAM transition above
                end
            end
        end
    end

    // ========================================
    // Debug status
    // ========================================
    assign debug_status = {
        13'd0,                  // reserved [31:19]
        pool_state,             // FSM state [18:16]
        curr_channel[7:0],      // current channel [15:8]
        out_y[3:0],             // last output Y [7:4]
        out_x[3:0]              // last output X [3:0]
    };

    // ========================================
    // Simulation-only tracing
    // ========================================
    `ifdef SIMULATION
        always @(posedge clk) begin
            if (pool_valid_out) begin
                $display("池化输出: [%0d,%0d,%0d] = %h @ %t",
                        pool_x, pool_y, pool_channel, pool_out, $time);
            end
        end

        // Trace every completed window as it is evaluated.
        always @(posedge clk) begin
            if (fire && window_complete) begin
                $display("池化窗口准备就绪:");
                $display("  [%h, %h]", prev_top, top_right);
                $display("  [%h, %h]", prev_pixel, pixel_in);
                $display("  最大值 = %h", max_value);
            end
        end

        initial begin
            if (STRIDE != 1 && STRIDE != 2)
                $display("max_pooling_2x2: unsupported STRIDE=%0d (expected 1 or 2)", STRIDE);
        end
    `endif

endmodule

平均池化模块

avg_pooling_2x2.v

`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////

// 2x2 average pooling, stride 1 or 2, streaming (one pixel per clock).
// Pixels must arrive in raster order inside each channel (x fastest, then y,
// then channel) with no gaps. The window is assembled from a one-row line
// buffer plus two neighbour registers, the same scheme as max_pooling_2x2.
// The result is floor(sum / 4) in two's complement. pool_valid_out is a
// one-cycle strobe.
module avg_pooling_2x2 #(
    parameter DATA_WIDTH = 16,
    parameter IMG_WIDTH = 32,
    parameter IMG_HEIGHT = 32,
    parameter CHANNELS = 16,
    parameter STRIDE = 2                 // 1 or 2
)(
    input wire clk,
    input wire rst_n,
    input wire enable,

    input wire signed [DATA_WIDTH-1:0] pixel_in,
    input wire pixel_valid_in,
    input wire [15:0] pixel_x,
    input wire [15:0] pixel_y,
    input wire [15:0] pixel_channel,

    output reg signed [DATA_WIDTH-1:0] pool_out,
    output reg pool_valid_out,
    output reg [15:0] pool_x,
    output reg [15:0] pool_y,
    output reg [15:0] pool_channel
);

    // Most recent pixel per column (the row above while streaming the current
    // row). It is not reset so it can map to RAM. Entries are written before
    // any window reads them.
    (* ram_style = "distributed" *)
    reg signed [DATA_WIDTH-1:0] line_buffer [0:IMG_WIDTH-1];

    reg signed [DATA_WIDTH-1:0] prev_pixel;  // pixel at (x-1, y)
    reg signed [DATA_WIDTH-1:0] prev_top;    // pixel at (x-1, y-1)

    wire in_range = (pixel_x < IMG_WIDTH);
    wire accept   = enable && pixel_valid_in && in_range;

    // Pixel at (x, y-1), read before this cycle's write overwrites it.
    wire signed [DATA_WIDTH-1:0] top_right = in_range ? line_buffer[pixel_x] : {DATA_WIDTH{1'b0}};

    // Sum of the current window. Two guard bits avoid overflow for four
    // DATA_WIDTH-bit operands, and all operands are sign-extended.
    wire signed [DATA_WIDTH+1:0] sum = prev_top + top_right + prev_pixel + pixel_in;

    // The current pixel is the bottom-right corner of a window.
    wire window_complete = (STRIDE == 1) ? ((pixel_x != 16'd0) && (pixel_y != 16'd0))
                                         : (pixel_x[0] & pixel_y[0]);
    wire [15:0] win_x = (STRIDE == 1) ? (pixel_x - 16'd1) : (pixel_x >> 1);
    wire [15:0] win_y = (STRIDE == 1) ? (pixel_y - 16'd1) : (pixel_y >> 1);

    always @(posedge clk) begin
        if (accept)
            line_buffer[pixel_x] <= pixel_in;
    end

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            prev_pixel     <= {DATA_WIDTH{1'b0}};
            prev_top       <= {DATA_WIDTH{1'b0}};
            pool_out       <= {DATA_WIDTH{1'b0}};
            pool_valid_out <= 1'b0;
            pool_x         <= 16'd0;
            pool_y         <= 16'd0;
            pool_channel   <= 16'd0;
        end else begin
            pool_valid_out <= 1'b0;              // one-cycle strobe

            if (accept) begin
                prev_pixel <= pixel_in;
                prev_top   <= top_right;

                if (window_complete) begin
                    // Dropping the two LSBs of a two's-complement value == floor(sum / 4).
                    pool_out       <= sum[DATA_WIDTH+1:2];
                    pool_valid_out <= 1'b1;
                    pool_x         <= win_x;
                    pool_y         <= win_y;
                    pool_channel   <= pixel_channel;
                end
            end
        end
    end

endmodule

自适应池化模块(SPP)

adaptive_pooling.v

`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////
// Adaptive pooling layer with SPP (spatial pyramid pooling) structure.
// Only the 1x1 (global average) path is implemented. The 2x2 / 4x4 states
// are placeholders that jump straight to CONCAT.
//////////////////////////////////////////////////////////////////////////////////

module adaptive_pooling #(
    parameter DATA_WIDTH = 16,
    parameter MAX_IMG_SIZE = 64,
    parameter CHANNELS = 256,
    parameter NUM_POOLS = 3          // number of SPP levels (e.g. 1x1, 2x2, 4x4)
)(
    input wire clk,
    input wire rst_n,
    input wire enable,

    // Configuration
    input wire [7:0] input_height,
    input wire [7:0] input_width,
    input wire [2:0] pool_size,      // 0: 1x1, 1: 2x2, 2: 4x4

    // Input stream
    input wire signed [DATA_WIDTH-1:0] feature_in,
    input wire feature_valid_in,
    input wire feature_last,         // marks the last pixel of a frame

    // Output (concatenated features)
    output reg signed [DATA_WIDTH-1:0] pool_out,
    output reg pool_valid_out,
    output reg [7:0] pool_index,     // index of the emitted feature
    output reg pool_complete         // set when a frame is done, cleared when the next frame starts
);

    // State machine
    reg [2:0] state;
    localparam IDLE = 3'b000;
    localparam POOL_1x1 = 3'b001;
    localparam POOL_2x2 = 3'b010;
    localparam POOL_4x4 = 3'b011;
    localparam CONCAT = 3'b100;
    localparam DONE = 3'b101;

    // Counter wide enough for MAX_IMG_SIZE^2 pixels. The accumulator adds the
    // same number of guard bits so the sum of a full frame cannot overflow.
    // Frames longer than MAX_IMG_SIZE^2 pixels are not supported.
    localparam CNT_W = $clog2(MAX_IMG_SIZE * MAX_IMG_SIZE + 1);
    localparam ACC_W = DATA_WIDTH + CNT_W;

    // Per-level pooling buffers
    (* ram_style = "distributed" *)
    reg signed [DATA_WIDTH-1:0] pool_buffer_1x1;
    reg signed [DATA_WIDTH-1:0] pool_buffer_2x2 [0:3];
    reg signed [DATA_WIDTH-1:0] pool_buffer_4x4 [0:15];

    // Running sum / count; both are zero whenever the FSM is in IDLE.
    reg signed [ACC_W-1:0] accumulator;
    reg [CNT_W-1:0] pixel_count;

    // Values *including* the current pixel. Using these (not the registered
    // values) makes the mean include the final feature_last pixel.
    wire signed [ACC_W-1:0] acc_next   = accumulator + feature_in;
    wire [CNT_W-1:0]        count_next = pixel_count + 1'b1;
    // The divisor is made explicitly signed: a mixed signed/unsigned '/' would
    // be evaluated as unsigned and corrupt negative means.
    wire signed [ACC_W-1:0] mean_next  = acc_next / $signed({1'b0, count_next});

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            state <= IDLE;
            pool_out <= 0;
            pool_valid_out <= 0;
            pool_index <= 0;
            pool_complete <= 0;
            pool_buffer_1x1 <= 0;
            accumulator <= 0;
            pixel_count <= 0;

        end else if (enable) begin
            case (state)
                IDLE: begin
                    if (feature_valid_in) begin
                        pool_complete <= 1'b0;          // a new frame has started
                        case (pool_size)
                            3'd0: begin
                                // The first pixel belongs to the frame: accumulate it now.
                                accumulator <= acc_next;
                                pixel_count <= count_next;
                                if (feature_last) begin
                                    // Single-pixel frame
                                    pool_buffer_1x1 <= mean_next[DATA_WIDTH-1:0];
                                    state <= CONCAT;
                                end else begin
                                    state <= POOL_1x1;
                                end
                            end
                            3'd1: state <= POOL_2x2;
                            3'd2: state <= POOL_4x4;
                            default: state <= IDLE;
                        endcase
                    end
                end

                POOL_1x1: begin
                    // Global average pooling
                    if (feature_valid_in) begin
                        accumulator <= acc_next;
                        pixel_count <= count_next;
                        if (feature_last) begin
                            pool_buffer_1x1 <= mean_next[DATA_WIDTH-1:0];
                            state <= CONCAT;
                        end
                    end
                end

                POOL_2x2: begin
                    // TODO: 2x2 region pooling not implemented
                    state <= CONCAT;
                end

                POOL_4x4: begin
                    // TODO: 4x4 region pooling not implemented
                    state <= CONCAT;
                end

                CONCAT: begin
                    // Emit the concatenated features (only the 1x1 result for now)
                    pool_out <= pool_buffer_1x1;
                    pool_index <= 8'd0;
                    pool_valid_out <= 1'b1;
                    state <= DONE;
                end

                DONE: begin
                    pool_valid_out <= 1'b0;
                    pool_complete <= 1'b1;
                    accumulator <= 0;               // ready for the next frame
                    pixel_count <= 0;
                    state <= IDLE;
                end

                default: state <= IDLE;
            endcase
        end
    end

endmodule

测试

tb_pooling_layers.v

`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////
// Pooling layer testbench: streams a CHANNELS x IMG_SIZE x IMG_SIZE image into
// max_pooling_2x2 (stride 2), checks every result against a reference max and
// reports PASSED/FAILED. A watchdog stops the run if pooling_done never appears.
//////////////////////////////////////////////////////////////////////////////////

module tb_pooling_layers;

    // Parameters
    parameter DATA_WIDTH = 16;
    parameter IMG_SIZE = 8;          // small image keeps the simulation short
    parameter CHANNELS = 2;
    parameter TIMEOUT_NS = 100000;   // watchdog limit

    // Test signals
    reg clk;
    reg rst_n;
    reg enable;
    reg start;

    // Stimulus
    reg signed [DATA_WIDTH-1:0] test_image [0:IMG_SIZE*IMG_SIZE-1];
    reg signed [DATA_WIDTH-1:0] pixel_in;
    reg pixel_valid_in;
    reg [15:0] pixel_x, pixel_y, pixel_channel;

    // Max-pooling outputs
    wire signed [DATA_WIDTH-1:0] max_pool_out;
    wire max_pool_valid;
    wire [15:0] max_pool_x, max_pool_y, max_pool_channel;
    wire max_pooling_done;

    // Loop variables live at module scope: `for (integer ...)` is SystemVerilog-only.
    integer ch, iy, ix, init_i;
    integer out_count;
    integer err_count;
    reg done_seen;
    reg signed [DATA_WIDTH-1:0] ref_max;

    // DUT - max pooling
    max_pooling_2x2 #(
        .DATA_WIDTH(DATA_WIDTH),
        .IMG_WIDTH(IMG_SIZE),
        .IMG_HEIGHT(IMG_SIZE),
        .CHANNELS(CHANNELS),
        .STRIDE(2)
    ) max_pool_dut (
        .clk(clk),
        .rst_n(rst_n),
        .enable(enable),
        .start(start),
        .pixel_in(pixel_in),
        .pixel_valid_in(pixel_valid_in),
        .pixel_x(pixel_x),
        .pixel_y(pixel_y),
        .pixel_channel(pixel_channel),
        .pool_out(max_pool_out),
        .pool_valid_out(max_pool_valid),
        .pool_x(max_pool_x),
        .pool_y(max_pool_y),
        .pool_channel(max_pool_channel),
        .pooling_done(max_pooling_done),
        .debug_status()
    );

    // Pixel value streamed for (px, py) of channel c
    function signed [DATA_WIDTH-1:0] pixel_value;
        input integer px, py, c;
        begin
            pixel_value = test_image[py*IMG_SIZE + px] + c*16'h0100;
        end
    endfunction

    function signed [DATA_WIDTH-1:0] max2;
        input signed [DATA_WIDTH-1:0] a, b;
        begin
            max2 = (a > b) ? a : b;
        end
    endfunction

    // Clock: 200 MHz
    initial begin
        clk = 0;
        forever #2.5 clk = ~clk;
    end

    // Ramp test image: test_image[i] = (i + 1) * 16'h0010. This keeps the
    // original first-row values 16'h0010..16'h0080 and defines every other
    // pixel, which were previously left as X.
    initial begin
        for (init_i = 0; init_i < IMG_SIZE*IMG_SIZE; init_i = init_i + 1)
            test_image[init_i] = (init_i + 1) * 16'h0010;
    end

    // Main sequence
    initial begin
        $dumpfile("pooling_test.vcd");
        $dumpvars(0, tb_pooling_layers);

        $display("========================================");
        $display("    池化层测试开始");
        $display("========================================");

        // Initialisation
        rst_n = 0;
        enable = 0;
        start = 0;
        pixel_in = 0;
        pixel_valid_in = 0;
        pixel_x = 0;
        pixel_y = 0;
        pixel_channel = 0;
        out_count = 0;
        err_count = 0;
        done_seen = 0;

        #20 rst_n = 1;
        #10 enable = 1;

        // Max pooling test
        $display("\n--- 测试2x2最大池化 ---");
        // Drive on the falling edge so the DUT samples stable values.
        @(negedge clk);
        start = 1;
        @(negedge clk);
        start = 0;

        // Stream the image: one pixel per clock, raster order, channel-major
        for (ch = 0; ch < CHANNELS; ch = ch + 1) begin
            for (iy = 0; iy < IMG_SIZE; iy = iy + 1) begin
                for (ix = 0; ix < IMG_SIZE; ix = ix + 1) begin
                    pixel_in = pixel_value(ix, iy, ch);
                    pixel_x = ix;
                    pixel_y = iy;
                    pixel_channel = ch;
                    pixel_valid_in = 1;
                    @(negedge clk);
                end
            end
        end
        pixel_valid_in = 0;

        // Wait for completion (latched so a one-cycle done strobe is not missed)
        wait (done_seen);
        $display("最大池化完成!");

        #50;
        if (out_count != CHANNELS*(IMG_SIZE/2)*(IMG_SIZE/2) || err_count != 0)
            $display("FAILED: %0d outputs (expected %0d), %0d mismatches",
                     out_count, CHANNELS*(IMG_SIZE/2)*(IMG_SIZE/2), err_count);
        else
            $display("PASSED: %0d outputs checked", out_count);

        $display("\n========================================");
        $display("    池化层测试完成");
        $display("========================================");
        $finish;
    end

    // Watchdog
    initial begin
        #(TIMEOUT_NS);
        $display("TIMEOUT: pooling_done not observed within %0d ns", TIMEOUT_NS);
        $finish;
    end

    // Output monitor and reference check
    always @(posedge clk) begin
        if (max_pool_valid) begin
            ref_max = max2(max2(pixel_value(2*max_pool_x,     2*max_pool_y,     max_pool_channel),
                                pixel_value(2*max_pool_x + 1, 2*max_pool_y,     max_pool_channel)),
                           max2(pixel_value(2*max_pool_x,     2*max_pool_y + 1, max_pool_channel),
                                pixel_value(2*max_pool_x + 1, 2*max_pool_y + 1, max_pool_channel)));
            $display("最大池化输出[%0d,%0d,%0d] = %h",
                    max_pool_x, max_pool_y, max_pool_channel, max_pool_out);
            if (max_pool_out !== ref_max) begin
                $display("  MISMATCH: expected %h", ref_max);
                err_count = err_count + 1;
            end
            out_count = out_count + 1;
        end
        if (max_pooling_done)
            done_seen = 1'b1;
    end

endmodule

激活函数和池化层集成

集成模块

activation_pooling_unit.v

`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////
// Activation + pooling integration unit.
//   activation_type: 0 = none, 1 = ReLU, 2 = SiLU (via silu_activation), 3 = none
//   pooling_type:    0 = none (activation output is forwarded),
//                    1 = 2x2 max, 2 = 2x2 average, 3 = 2x2 max
//   pool_stride:     currently ignored; both pooling paths use STRIDE = 2.
// NOTE(review): assumes conv_out arrives in raster order, channel-major
// (x fastest, then y, then channel), IMG_WIDTH*IMG_HEIGHT*CHANNELS samples per
// frame, and that configuration is stable during a frame. Confirm with the
// producer.
//////////////////////////////////////////////////////////////////////////////////

module activation_pooling_unit #(
    parameter DATA_WIDTH = 16,
    parameter IMG_WIDTH = 32,
    parameter IMG_HEIGHT = 32,
    parameter CHANNELS = 16
)(
    input wire clk,
    input wire rst_n,

    // Configuration
    input wire [1:0] activation_type,  // 0: none, 1: ReLU, 2: SiLU
    input wire [1:0] pooling_type,     // 0: none, 1: max, 2: average
    input wire [1:0] pool_stride,      // reserved, see header

    // Data input
    input wire signed [DATA_WIDTH-1:0] conv_out,
    input wire conv_valid,

    // Data output
    output wire signed [DATA_WIDTH-1:0] final_out,
    output wire final_valid
);

    // SiLU path
    wire signed [DATA_WIDTH-1:0] act_out;
    wire act_valid;

    silu_activation #(
        .DATA_WIDTH(DATA_WIDTH)
    ) activation_inst (
        .clk(clk),
        .rst_n(rst_n),
        .enable(activation_type == 2'd2),
        .bypass(activation_type != 2'd2),
        .data_in(conv_out),
        .data_valid_in(conv_valid),
        .data_out(act_out),
        .data_valid_out(act_valid)
    );

    // Activation select. SiLU comes from silu_activation. ReLU and "none" are
    // computed here directly from conv_out, so they do not depend on the
    // (unseen) behaviour of silu_activation while it is disabled.
    wire use_silu = (activation_type == 2'd2);
    wire signed [DATA_WIDTH-1:0] relu_out = conv_out[DATA_WIDTH-1] ? {DATA_WIDTH{1'b0}} : conv_out;
    wire signed [DATA_WIDTH-1:0] act_data = use_silu ? act_out :
                                            (activation_type == 2'd1) ? relu_out : conv_out;
    wire act_data_valid = use_silu ? act_valid : conv_valid;

    // Pixel coordinate generator for the pooling stages (raster, channel-major).
    reg [15:0] pix_x, pix_y, pix_ch;
    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            pix_x  <= 16'd0;
            pix_y  <= 16'd0;
            pix_ch <= 16'd0;
        end else if (act_data_valid) begin
            if (pix_x == IMG_WIDTH - 1) begin
                pix_x <= 16'd0;
                if (pix_y == IMG_HEIGHT - 1) begin
                    pix_y  <= 16'd0;
                    pix_ch <= (pix_ch == CHANNELS - 1) ? 16'd0 : pix_ch + 16'd1;
                end else begin
                    pix_y <= pix_y + 16'd1;
                end
            end else begin
                pix_x <= pix_x + 16'd1;
            end
        end
    end

    // The first pixel of every frame arms the max-pooling block.
    wire frame_start = act_data_valid && (pix_x == 16'd0) && (pix_y == 16'd0) && (pix_ch == 16'd0);

    wire use_max = (pooling_type == 2'd1) || (pooling_type == 2'd3);
    wire use_avg = (pooling_type == 2'd2);

    wire signed [DATA_WIDTH-1:0] max_out;
    wire signed [DATA_WIDTH-1:0] avg_out;
    wire max_valid;
    wire avg_valid;

    // Max pooling
    max_pooling_2x2 #(
        .DATA_WIDTH(DATA_WIDTH),
        .IMG_WIDTH(IMG_WIDTH),
        .IMG_HEIGHT(IMG_HEIGHT),
        .CHANNELS(CHANNELS),
        .STRIDE(2)
    ) pooling_inst (
        .clk(clk),
        .rst_n(rst_n),
        .enable(use_max),
        .start(frame_start),
        .pixel_in(act_data),
        .pixel_valid_in(act_data_valid),
        .pixel_x(pix_x),
        .pixel_y(pix_y),
        .pixel_channel(pix_ch),
        .pool_out(max_out),
        .pool_valid_out(max_valid),
        .pool_x(),
        .pool_y(),
        .pool_channel(),
        .pooling_done(),
        .debug_status()
    );

    // Average pooling
    avg_pooling_2x2 #(
        .DATA_WIDTH(DATA_WIDTH),
        .IMG_WIDTH(IMG_WIDTH),
        .IMG_HEIGHT(IMG_HEIGHT),
        .CHANNELS(CHANNELS),
        .STRIDE(2)
    ) avg_pooling_inst (
        .clk(clk),
        .rst_n(rst_n),
        .enable(use_avg),
        .pixel_in(act_data),
        .pixel_valid_in(act_data_valid),
        .pixel_x(pix_x),
        .pixel_y(pix_y),
        .pixel_channel(pix_ch),
        .pool_out(avg_out),
        .pool_valid_out(avg_valid),
        .pool_x(),
        .pool_y(),
        .pool_channel()
    );

    // Output select: without pooling, the activation result is forwarded directly.
    assign final_out   = use_max ? max_out   : use_avg ? avg_out   : act_data;
    assign final_valid = use_max ? max_valid : use_avg ? avg_valid : act_data_valid;

endmodule
posted @ 2025-12-31 18:14  李白的白  阅读(7)  评论(0)    收藏  举报