激活函数实现

激活函数实现

1.1:创建激活函数工程目录

创建激活函数目录

   - src          (存放源代码)
   - testbench    (存放测试文件)
   - docs         (存放文档)
   - lut_data     (存放查找表数据)
   - python_utils (存放Python辅助脚本)

1.2:生成SiLU激活函数查找表

创建Python脚本生成LUT数据

#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import os

def sigmoid(x):
    """sigmoid函数"""
    return 1 / (1 + np.exp(-x))

def silu(x):
    """SiLU激活函数 (Swish): x * sigmoid(x)"""
    return x * sigmoid(x)

def generate_silu_lut(bit_width=16, lut_size=256, input_range=8.0):
    """
    生成SiLU激活函数的查找表
    
    参数:
    - bit_width: 数据位宽(默认16位)
    - lut_size: 查找表大小(默认256项)
    - input_range: 输入范围(-input_range到+input_range)
    """
    
    print("="*60)
    print("SiLU激活函数查找表生成器")
    print("="*60)
    print(f"配置参数:")
    print(f"  - 数据位宽: {bit_width} bits")
    print(f"  - LUT大小: {lut_size} entries")
    print(f"  - 输入范围: [{-input_range}, {input_range}]")
    print()
    
    # 创建输入值数组
    x_values = np.linspace(-input_range, input_range, lut_size)
    
    # 计算SiLU值
    silu_values = silu(x_values)
    
    # 量化到固定点表示
    # 使用Q8.8格式(8位整数,8位小数)
    scale_factor = 2**(bit_width//2)
    silu_quantized = np.round(silu_values * scale_factor).astype(int)
    
    # 限制在位宽范围内
    max_val = 2**(bit_width-1) - 1
    min_val = -2**(bit_width-1)
    silu_quantized = np.clip(silu_quantized, min_val, max_val)
    
    return x_values, silu_values, silu_quantized, scale_factor

def save_lut_to_file(lut_data, filename, format='hex'):
    """
    保存查找表到文件
    
    参数:
    - lut_data: 查找表数据
    - filename: 输出文件名
    - format: 输出格式 ('hex', 'bin', 'coe')
    """
    
    print(f"保存LUT到文件: {filename}")
    
    if format == 'hex':
        with open(filename, 'w') as f:
            f.write("// SiLU激活函数查找表 - Hexadecimal格式\n")
            f.write("// 每行一个16位十六进制值\n")
            f.write(f"// 总共 {len(lut_data)} 个条目\n\n")
            
            for i, value in enumerate(lut_data):
                # 转换为无符号表示
                if value < 0:
                    value = (1 << 16) + value
                f.write(f"{value:04X}  // 索引 {i:3d}\n")
                
    elif format == 'coe':
        # Xilinx COE格式(用于Block RAM初始化)
        with open(filename, 'w') as f:
            f.write("; SiLU激活函数查找表 - COE格式\n")
            f.write("; 用于Xilinx Block RAM初始化\n")
            f.write("memory_initialization_radix=16;\n")
            f.write("memory_initialization_vector=\n")
            
            for i, value in enumerate(lut_data):
                if value < 0:
                    value = (1 << 16) + value
                    
                if i < len(lut_data) - 1:
                    f.write(f"{value:04X},\n")
                else:
                    f.write(f"{value:04X};\n")
                    
    elif format == 'bin':
        # 二进制格式
        with open(filename, 'w') as f:
            f.write("// SiLU激活函数查找表 - Binary格式\n")
            for i, value in enumerate(lut_data):
                if value < 0:
                    value = (1 << 16) + value
                f.write(f"{value:016b}  // 索引 {i:3d}\n")
    
    print(f"  文件保存成功!")

def plot_silu_function(x_values, silu_values, silu_quantized, scale_factor):
    """绘制SiLU函数图形"""
    
    plt.figure(figsize=(12, 5))
    
    # 子图1:原始SiLU函数
    plt.subplot(1, 2, 1)
    plt.plot(x_values, silu_values, 'b-', linewidth=2, label='SiLU (float)')
    plt.plot(x_values, silu_quantized/scale_factor, 'r--', 
             linewidth=1, label='SiLU (quantized)')
    plt.grid(True, alpha=0.3)
    plt.xlabel('输入值 x')
    plt.ylabel('SiLU(x) = x * sigmoid(x)')
    plt.title('SiLU激活函数')
    plt.legend()
    
    # 子图2:量化误差
    plt.subplot(1, 2, 2)
    error = silu_values - silu_quantized/scale_factor
    plt.plot(x_values, error * 100, 'g-', linewidth=1)
    plt.grid(True, alpha=0.3)
    plt.xlabel('输入值 x')
    plt.ylabel('误差 (%)')
    plt.title('量化误差分析')
    
    plt.tight_layout()
    plt.savefig('silu_function_plot.png', dpi=150)
    plt.show()
    
    print("\n图形已保存为: silu_function_plot.png")

def generate_testbench_data(num_tests=100):
    """生成测试激励数据"""
    
    print("\n生成测试数据...")
    
    # 生成随机测试输入
    test_inputs = np.random.uniform(-8, 8, num_tests)
    
    # 计算期望输出
    expected_outputs = silu(test_inputs)
    
    # 保存测试数据
    with open('test_vectors.txt', 'w') as f:
        f.write("// SiLU测试向量\n")
        f.write("// 格式: 输入值, 期望输出\n\n")
        
        for inp, out in zip(test_inputs, expected_outputs):
            f.write(f"{inp:8.4f}, {out:8.4f}\n")
    
    print(f"  生成了 {num_tests} 个测试向量")
    print(f"  保存到: test_vectors.txt")

def main():
    """主函数"""
    
    # 设置输出目录
    output_dir = "../lut_data"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"创建输出目录: {output_dir}")
    
    os.chdir(output_dir)
    
    # 生成LUT
    x_vals, silu_float, silu_quant, scale = generate_silu_lut(
        bit_width=16,
        lut_size=256,
        input_range=8.0
    )
    
    # 保存到不同格式的文件
    print("\n保存LUT文件...")
    save_lut_to_file(silu_quant, 'silu_lut.hex', 'hex')
    save_lut_to_file(silu_quant, 'silu_lut.coe', 'coe')
    save_lut_to_file(silu_quant, 'silu_lut.bin', 'bin')
    
    # 生成测试数据
    generate_testbench_data(100)
    
    # 绘制函数图
    print("\n生成可视化图形...")
    plot_silu_function(x_vals, silu_float, silu_quant, scale)
    
    # 打印统计信息
    print("\n统计信息:")
    print(f"  最大量化误差: {np.max(np.abs(silu_float - silu_quant/scale)):.6f}")
    print(f"  平均量化误差: {np.mean(np.abs(silu_float - silu_quant/scale)):.6f}")
    print(f"  查找表大小: {len(silu_quant) * 2} bytes")
    
    print("\n完成!所有文件已生成。")

if __name__ == "__main__":
    main()
  1. 保存文件为:D:\yolo_v10_fpga_project\activation_functions\python_utils\generate_silu_lut.py

运行脚本生成LUT

pip install numpy matplotlib
python generate_silu_lut.py

检查生成的文件

  • 脚本会在 lut_data 文件夹中生成:
    • silu_lut.hex - 十六进制格式
    • silu_lut.coe - Xilinx COE格式
    • silu_lut.bin - 二进制格式
    • test_vectors.txt - 测试向量
    • silu_function_plot.png - 函数图形

1.3:在Vivado中创建SiLU激活函数模块

创建SiLU激活函数Verilog模块

编写SiLU激活函数代码

`timescale 1ns / 1ps
//////////////////////////////////////////////////////////////////////////////////

// Description: 
//   SiLU (Swish) 激活函数硬件实现
//   使用查找表(LUT)和线性插值实现高精度激活
//   SiLU(x) = x * sigmoid(x)
// 
// Dependencies: 
//   - silu_lut.coe文件(查找表初始化数据)
// 
// Revision:
// Revision 1.0 - 初始版本
// Additional Comments:
//   - 支持16位定点数输入输出
//   - 使用256项查找表
//   - 线性插值提高精度
//////////////////////////////////////////////////////////////////////////////////

module silu_activation #(
    // 参数定义
    parameter DATA_WIDTH = 16,        // 数据位宽
    parameter LUT_ADDR_WIDTH = 8,     // LUT地址位宽(256项)
    parameter LUT_DATA_WIDTH = 16,    // LUT数据位宽
    parameter FRAC_BITS = 8,          // 小数位数(Q8.8格式)
    parameter PIPELINE_STAGES = 3     // 流水线级数
)(
    // 时钟和复位
    input wire clk,
    input wire rst_n,
    
    // 控制信号
    input wire enable,                // 模块使能
    input wire bypass,                 // 旁路模式(直通)
    
    // 数据输入
    input wire signed [DATA_WIDTH-1:0] data_in,
    input wire data_valid_in,
    
    // 数据输出
    output reg signed [DATA_WIDTH-1:0] data_out,
    output reg data_valid_out,
    
    // 调试接口
    output wire [31:0] debug_info
);

    // ========================================
    // 内部信号定义
    // ========================================
    
    // LUT存储器信号
    reg [LUT_DATA_WIDTH-1:0] silu_lut_mem [0:(1<<LUT_ADDR_WIDTH)-1];
    
    // 地址生成信号
    wire [LUT_ADDR_WIDTH-1:0] lut_addr;
    wire [FRAC_BITS-1:0] lut_frac;
    wire input_sign;
    wire [DATA_WIDTH-1:0] abs_input;
    
    // 查找表输出
    reg [LUT_DATA_WIDTH-1:0] lut_data_0;
    reg [LUT_DATA_WIDTH-1:0] lut_data_1;
    
    // 插值计算信号
    reg signed [DATA_WIDTH+FRAC_BITS-1:0] interp_result;
    reg signed [DATA_WIDTH-1:0] final_result;
    
    // 流水线寄存器
    reg signed [DATA_WIDTH-1:0] pipe_stage1_data;
    reg signed [DATA_WIDTH-1:0] pipe_stage2_data;
    reg signed [DATA_WIDTH-1:0] pipe_stage3_data;
    reg pipe_stage1_valid;
    reg pipe_stage2_valid;
    reg pipe_stage3_valid;
    
    // 性能计数器
    reg [31:0] activation_count;
    reg [31:0] bypass_count;
    
    // ========================================
    // LUT初始化
    // ========================================
    
    // 初始化查找表(从文件加载)
    initial begin
        // 方法1:使用$readmemh从十六进制文件加载
        $readmemh("silu_lut.hex", silu_lut_mem);
        
        // 方法2:如果使用COE文件,在IP核中配置
        // 这里是备用的硬编码初始化(部分值示例)
        /*
        silu_lut_mem[0]   = 16'h0000;  // SiLU(-8.0)
        silu_lut_mem[1]   = 16'h0001;  // SiLU(-7.94)
        silu_lut_mem[2]   = 16'h0002;  // SiLU(-7.88)
        // ... 更多初始化值
        silu_lut_mem[127] = 16'h0400;  // SiLU(0.0)
        silu_lut_mem[128] = 16'h0800;  // SiLU(0.0)
        // ... 更多初始化值
        silu_lut_mem[255] = 16'h1FFF;  // SiLU(8.0)
        */
        
        $display("SiLU LUT初始化完成,共%d项", 1<<LUT_ADDR_WIDTH);
    end
    
    // ========================================
    // 地址生成逻辑
    // ========================================
    
    // 提取符号位
    assign input_sign = data_in[DATA_WIDTH-1];
    
    // 计算绝对值
    assign abs_input = input_sign ? -data_in : data_in;
    
    // 生成LUT地址和小数部分
    // 假设输入范围是[-8, 8],映射到[0, 255]
    assign lut_addr = abs_input[DATA_WIDTH-2:DATA_WIDTH-1-LUT_ADDR_WIDTH];
    assign lut_frac = abs_input[DATA_WIDTH-2-LUT_ADDR_WIDTH:DATA_WIDTH-1-LUT_ADDR_WIDTH-FRAC_BITS];
    
    // ========================================
    // 流水线第1级:查找表读取
    // ========================================
    
    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            lut_data_0 <= {LUT_DATA_WIDTH{1'b0}};
            lut_data_1 <= {LUT_DATA_WIDTH{1'b0}};
            pipe_stage1_data <= {DATA_WIDTH{1'b0}};
            pipe_stage1_valid <= 1'b0;
        end else if (enable) begin
            if (bypass) begin
                // 旁路模式:直接传递输入
                pipe_stage1_data <= data_in;
                pipe_stage1_valid <= data_valid_in;
                bypass_count <= bypass_count + 1;
            end else begin
                // 正常模式:读取LUT
                lut_data_0 <= silu_lut_mem[lut_addr];
                
                // 读取下一个值用于插值(注意边界)
                if (lut_addr < (1<<LUT_ADDR_WIDTH)-1) begin
                    lut_data_1 <= silu_lut_mem[lut_addr + 1];
                end else begin
                    lut_data_1 <= silu_lut_mem[lut_addr];  // 边界处理
                end
                
                pipe_stage1_data <= data_in;
                pipe_stage1_valid <= data_valid_in;
                
                if (data_valid_in) begin
                    activation_count <= activation_count + 1;
                end
            end
        end else begin
            pipe_stage1_valid <= 1'b0;
        end
    end
    
    // ========================================
    // 流水线第2级:线性插值
    // ========================================
    
    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            interp_result <= {(DATA_WIDTH+FRAC_BITS){1'b0}};
            pipe_stage2_data <= {DATA_WIDTH{1'b0}};
            pipe_stage2_valid <= 1'b0;
        end else if (enable) begin
            if (bypass) begin
                // 旁路模式
                pipe_stage2_data <= pipe_stage1_data;
                pipe_stage2_valid <= pipe_stage1_valid;
            end else begin
                // 线性插值计算
                // result = lut_data_0 + (lut_data_1 - lut_data_0) * frac / 256
                reg signed [DATA_WIDTH:0] diff;
                reg signed [DATA_WIDTH+FRAC_BITS:0] prod;
                
                diff = lut_data_1 - lut_data_0;
                prod = diff * lut_frac;
                interp_result <= lut_data_0 + (prod >> FRAC_BITS);
                
                pipe_stage2_data <= pipe_stage1_data;
                pipe_stage2_valid <= pipe_stage1_valid;
            end
        end else begin
            pipe_stage2_valid <= 1'b0;
        end
    end
    
    // ========================================
    // 流水线第3级:最终输出
    // ========================================
    
    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            final_result <= {DATA_WIDTH{1'b0}};
            pipe_stage3_data <= {DATA_WIDTH{1'b0}};
            pipe_stage3_valid <= 1'b0;
            data_out <= {DATA_WIDTH{1'b0}};
            data_valid_out <= 1'b0;
        end else if (enable) begin
            if (bypass) begin
                // 旁路模式
                data_out <= pipe_stage2_data;
                data_valid_out <= pipe_stage2_valid;
            end else begin
                // 应用符号并输出
                if (pipe_stage2_data[DATA_WIDTH-1]) begin
                    // 负输入的处理(SiLU对称性)
                    final_result <= -interp_result[DATA_WIDTH-1:0];
                end else begin
                    final_result <= interp_result[DATA_WIDTH-1:0];
                end
                
                // 饱和处理
                if (interp_result > {1'b0, {(DATA_WIDTH-1){1'b1}}}) begin
                    data_out <= {1'b0, {(DATA_WIDTH-1){1'b1}}};  // 最大正值
                end else if (interp_result < {1'b1, {(DATA_WIDTH-1){1'b0}}}) begin
                    data_out <= {1'b1, {(DATA_WIDTH-1){1'b0}}};  // 最大负值
                end else begin
                    data_out <= final_result;
                end
                
                data_valid_out <= pipe_stage2_valid;
            end
        end else begin
            data_valid_out <= 1'b0;
        end
    end
    
    // ========================================
    // 性能监控
    // ========================================
    
    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            activation_count <= 32'd0;
            bypass_count <= 32'd0;
        end
        // 计数逻辑已在上面的流水线中实现
    end
    
    // 调试信息输出
    assign debug_info = {
        8'd0,                           // 保留[31:24]
        lut_addr,                       // LUT地址[23:16]
        activation_count[15:0]          // 激活计数[15:0]
    };
    
    // ========================================
    // 断言和验证(仿真用)
    // ========================================
    
    `ifdef SIMULATION
        // 检查输入有效性
        always @(posedge clk) begin
            if (data_valid_in && enable && !bypass) begin
                if (^data_in === 1'bx) begin
                    $display("ERROR: SiLU输入包含X值 @ %t", $time);
                end
            end
        end
        
        // 监控溢出
        reg overflow_detected;
        always @(posedge clk) begin
            overflow_detected <= (interp_result > {1'b0, {(DATA_WIDTH-1){1'b1}}}) ||
                               (interp_result < {1'b1, {(DATA_WIDTH-1){1'b0}}});
            if (overflow_detected && pipe_stage2_valid) begin
                $display("WARNING: SiLU输出饱和 @ %t, 输入=%h", 
                        $time, pipe_stage2_data);
            end
        end
    `endif

endmodule

创建SiLU激活函数测试平台

tb_silu_activation.v

`timescale 1ns / 1ps
module tb_silu_activation;

    // 参数定义
    parameter DATA_WIDTH = 16;
    parameter CLK_PERIOD = 5;  // 200MHz时钟
    
    // 测试信号
    reg clk;
    reg rst_n;
    reg enable;
    reg bypass;
    reg signed [DATA_WIDTH-1:0] data_in;
    reg data_valid_in;
    
    wire signed [DATA_WIDTH-1:0] data_out;
    wire data_valid_out;
    wire [31:0] debug_info;
    
    // 测试数据存储
    reg signed [DATA_WIDTH-1:0] test_inputs [0:99];
    real expected_outputs [0:99];
    integer test_index;
    
    // DUT实例化
    silu_activation #(
        .DATA_WIDTH(DATA_WIDTH),
        .LUT_ADDR_WIDTH(8),
        .PIPELINE_STAGES(3)
    ) DUT (
        .clk(clk),
        .rst_n(rst_n),
        .enable(enable),
        .bypass(bypass),
        .data_in(data_in),
        .data_valid_in(data_valid_in),
        .data_out(data_out),
        .data_valid_out(data_valid_out),
        .debug_info(debug_info)
    );
    
    // 时钟生成
    initial begin
        clk = 0;
        forever #(CLK_PERIOD/2) clk = ~clk;
    end
    
    // 读取测试向量
    initial begin
        $readmemh("test_inputs.hex", test_inputs);
        // 注意:expected_outputs需要从文件读取或计算
    end
    
    // 主测试序列
    initial begin
        // 初始化波形记录
        $dumpfile("silu_test.vcd");
        $dumpvars(0, tb_silu_activation);
        
        // 显示测试开始
        $display("========================================");
        $display("    SiLU激活函数测试开始");
        $display("========================================");
        $display("时间\t\t操作\t\t输入\t\t输出");
        
        // 初始化信号
        rst_n = 0;
        enable = 0;
        bypass = 0;
        data_in = 0;
        data_valid_in = 0;
        test_index = 0;
        
        // 复位
        #(CLK_PERIOD*10);
        rst_n = 1;
        enable = 1;
        #(CLK_PERIOD*5);
        
        // ========================================
        // 测试1:基本功能测试
        // ========================================
        $display("\n--- 测试1:基本SiLU激活 ---");
        
        // 测试正值
        data_in = 16'h0100;  // 1.0 in Q8.8
        data_valid_in = 1;
        #CLK_PERIOD;
        data_valid_in = 0;
        
        // 等待流水线延迟
        #(CLK_PERIOD*5);
        $display("%t\t正值测试\t%h\t%h", $time, 16'h0100, data_out);
        
        // 测试负值
        data_in = 16'hFF00;  // -1.0 in Q8.8
        data_valid_in = 1;
        #CLK_PERIOD;
        data_valid_in = 0;
        
        #(CLK_PERIOD*5);
        $display("%t\t负值测试\t%h\t%h", $time, 16'hFF00, data_out);
        
        // 测试零值
        data_in = 16'h0000;
        data_valid_in = 1;
        #CLK_PERIOD;
        data_valid_in = 0;
        
        #(CLK_PERIOD*5);
        $display("%t\t零值测试\t%h\t%h", $time, 16'h0000, data_out);
        
        // ========================================
        // 测试2:连续数据流测试
        // ========================================
        $display("\n--- 测试2:连续数据流 ---");
        
        for (test_index = 0; test_index < 10; test_index = test_index + 1) begin
            data_in = test_inputs[test_index];
            data_valid_in = 1;
            #CLK_PERIOD;
        end
        data_valid_in = 0;
        
        // 等待所有数据处理完成
        #(CLK_PERIOD*10);
        
        // ========================================
        // 测试3:旁路模式测试
        // ========================================
        $display("\n--- 测试3:旁路模式 ---");
        
        bypass = 1;
        data_in = 16'h5555;
        data_valid_in = 1;
        #CLK_PERIOD;
        data_valid_in = 0;
        
        #(CLK_PERIOD*5);
        if (data_out == 16'h5555) begin
            $display("%t\t旁路测试\tPASS", $time);
        end else begin
            $display("%t\t旁路测试\tFAIL: 期望=%h, 实际=%h", 
                    $time, 16'h5555, data_out);
        end
        
        bypass = 0;
        
        // ========================================
        // 测试4:边界值测试
        // ========================================
        $display("\n--- 测试4:边界值测试 ---");
        
        // 最大正值
        data_in = 16'h7FFF;
        data_valid_in = 1;
        #CLK_PERIOD;
        data_valid_in = 0;
        #(CLK_PERIOD*5);
        $display("%t\t最大正值\t%h\t%h", $time, 16'h7FFF, data_out);
        
        // 最大负值
        data_in = 16'h8000;
        data_valid_in = 1;
        #CLK_PERIOD;
        data_valid_in = 0;
        #(CLK_PERIOD*5);
        $display("%t\t最大负值\t%h\t%h", $time, 16'h8000, data_out);
        
        // ========================================
        // 测试5:性能测试
        // ========================================
        $display("\n--- 测试5:性能测试 ---");
        
        // 发送100个连续数据
        for (test_index = 0; test_index < 100; test_index = test_index + 1) begin
            data_in = $random;
            data_valid_in = 1;
            #CLK_PERIOD;
        end
        data_valid_in = 0;
        
        // 等待处理完成
        #(CLK_PERIOD*10);
        
        // 读取调试信息
        $display("处理的激活数: %d", debug_info[15:0]);
        
        // ========================================
        // 测试完成
        // ========================================
        #(CLK_PERIOD*10);
        $display("\n========================================");
        $display("    SiLU激活函数测试完成");
        $display("========================================");
        $finish;
    end
    
    // 监控输出
    always @(posedge clk) begin
        if (data_valid_out) begin
            $display("输出: data=%h @ %t", data_out, $time);
        end
    end

endmodule
posted @ 2025-10-04 18:02  李白的白  阅读(14)  评论(0)    收藏  举报