pytorch基础练习及螺旋数据分类

2.1 pytorch 基础练习

  1. 定义数据部分

    一般定义数据都使用torch.xxx()。此方法既可以定义一个单独的数,也可以定义向量、矩阵等,非常方便。
    image

  2. 定义操作部分

    引入matplotlib库和numpy库后,可以方便的进行绘图操作。
    image

    2.2 螺旋数据分类

    首先,生成3类样本(不能使用简单的线性方法进行分类)。X矩阵存储样本坐标,Y矩阵存储样本种类。
    image

    接着,尝试用线性模型进行分类。可以看出效果不好,准确率仅为0.504,基本就是随机是否正确。
    image

    最后,构建两层神经网络进行分类。本质上的不同仅是引入了一个ReLU激活函数,其余部分完全一样,最后准确率有了显著提升,达到了0.949.
    image

posted @ 2021-10-06 21:02  珊瑚花海  阅读(81)  评论(0)    收藏  举报