pytorch基础练习及螺旋数据分类
2.1 pytorch 基础练习
-
定义数据部分
一般定义数据都使用torch.xxx()。此方法既可以定义一个单独的数,也可以定义向量、矩阵等,非常方便。
![image]()
-
定义操作部分
引入matplotlib库和numpy库后,可以方便的进行绘图操作。
![image]()
2.2 螺旋数据分类
首先,生成3类样本(不能使用简单的线性方法进行分类)。X矩阵存储样本坐标,Y矩阵存储样本种类。
![image]()
接着,尝试用线性模型进行分类。可以看出效果不好,准确率仅为0.504,基本就是随机是否正确。
![image]()
最后,构建两层神经网络进行分类。本质上的不同仅是引入了一个ReLU激活函数,其余部分完全一样,最后准确率有了显著提升,达到了0.949.
![image]()






浙公网安备 33010602011771号