【rust】《Rust深度学习[4]-理解线性网络(Candle)》

全连接/线性

在神经网络中,全连接层,也称为线性层,是一种层,其中来自一层的所有输入都连接到下一层的每个激活单元。在大多数流行的机器学习模型中,网络的最后几层是完全连接的。实际上,这种类型的层执行基于在先前层中学习的特征输出类别预测的任务。

全连接层的示例,具有四个输入节点和八个输出节点。

全连接层在输入中接收在先前卷积层中激活的节点向量。这个向量在被发送到输出层之前,会经过一个或多个密集层。在到达输出层之前,使用激活函数进行预测。虽然卷积层和池化层通常使用ReLU函数,但基于分类问题的类型,全连接层可以使用两种类型的激活函数:

  Sigmoid:逻辑函数,用于二进制分类问题。

  Softmax:一个更广义的逻辑激活函数,它确保输出层中的值总和为1。通常用于多类分类。

依赖

[package]
name = "mnist-ml-linear"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
# 使用 cargo add candle-core@0.4.1 下载
candle-core = "0.4.1"

代码

use candle_core::{DType, Device, Result, Tensor};

// 定义线性层
struct Linear {
    // 权重
    weight: Tensor,
    // 偏移量
    bias: Tensor,
}

// 线性层函数
impl Linear {
    fn forward(
        &self,
        x: &Tensor,
    ) -> Result<Tensor> {
        let x = x
            .contiguous()?
            // 将输入值乘以权重
            .matmul(&self.weight.contiguous()?)?;
        // 再加上偏移量
        x.broadcast_add(&self.bias)
    }
}

// 模型
struct Model {
    // 线性第一层
    first: Linear,
    // 线性第二层
    second: Linear,
}

// 模型函数
impl Model {
    fn forward(
        &self,
        image: &Tensor,
    ) -> Result<Tensor> {
        // 传入图片进行第一层线性分析
        let x = self.first.forward(image)?;
        // 将其与 ReLU 激活函数相乘
        let x = x.relu()?;
        // 进行第二层线性分析
        self.second.forward(&x)
    }
}

// 模仿线性模型分析过程
fn main() -> Result<()> {
    //  Device::new_cuda(0)?;
    //  Device::Cpu;
    // 使用CPU资源
    let device = Device::Cpu;

    // 创建demo(线性第一层)
    let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
    let bias = Tensor::zeros((100,), DType::F32, &device)?;
    let first = Linear {
        weight,
        bias,
    };
    // 创建demo(线性第二层)
    let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
    let bias = Tensor::zeros((10,), DType::F32, &device)?;
    let second = Linear {
        weight,
        bias,
    };
    // 创建模型
    let model = Model {
        first,
        second,
    };

    // demo图片
    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;

    // 开始模型推演
    let digit = model.forward(&dummy_image)?;

    println!("Digit {digit:?} digit");
    Ok(())
}

 

posted @ 2024-04-23 13:41  芋白  阅读(567)  评论(0)    收藏  举报