机器学习—降维-特征选择6-6(局部线性嵌入法/流形降维)

使用局部线性嵌入法(流形降维)对瑞士卷数据集降维

主要步骤流程:

  • 1. 导入包
  • 2. 生成瑞士卷数据集
  • 3. 可视化瑞士卷数据集
  • 4. 不同自变量个数对应的重建误差
  • 5. 可视化LLE降维效果

 

1. 导入包

In [1]:
# 导入包
import numpy as np
import matplotlib.pyplot as plt

 

2. 生成瑞士卷数据集

In [2]:
# 生成瑞士卷数据集
from sklearn import datasets
swiss_roll_dataset =datasets.make_swiss_roll(n_samples=1000)
x = swiss_roll_dataset[0]
y = np.floor(swiss_roll_dataset[1])
print(x.shape)
(1000, 3)
In [3]:
print(y.shape)
(1000,)
In [5]:
y
Out[5]:
array([12.,  7.,  8.,  8.,  5.,  9.,  8., 12.,  7., 10., 12., 13., 10.,
        9.,  6., 10.,  5., 12.,  8.,  5., 12.,  6., 11.,  5.,  6., 10.,
        5., 10.,  7.,  5., 10., 12., 12.,  4.,  5.,  8.,  5.,  9., 11.,
        9., 13., 12., 13., 12., 12.,  6.,  9.,  7.,  5., 13.,  7.,  7.,
       12., 11.,  8.,  6.,  5.,  8., 13., 13.,  5.,  6.,  8., 11., 10.,
        9., 13.,  5., 10.,  8.,  4.,  9., 10.,  5.,  6.,  6., 10.,  7.,
       10.,  5., 10.,  6.,  8., 11.,  9.,  9.,  7.,  7.,  8., 11.,  6.,
       10., 10.,  9., 10., 12.,  9.,  5., 12.,  9., 10.,  8., 12.,  9.,
        8.,  4., 10.,  6., 12.,  5.,  6., 12.,  6.,  5., 12., 12., 12.,
       12.,  9.,  5.,  5., 12.,  4.,  5., 13.,  9., 13., 11.,  5., 12.,
        9., 13.,  5.,  7., 11., 10.,  5., 11., 13.,  5.,  7.,  9.,  8.,
        9.,  9., 13., 13., 13.,  6.,  9., 12., 13., 12.,  5., 10.,  8.,
        7.,  9.,  8., 12.,  6., 12., 13.,  9.,  6.,  9., 11.,  8.,  7.,
        8.,  5.,  4.,  7., 11.,  8., 12.,  8.,  5.,  5.,  6.,  5., 10.,
        7.,  7.,  9.,  5.,  6.,  4., 11.,  9.,  7.,  8.,  9., 13., 10.,
        9.,  9.,  6.,  7., 12.,  5.,  6.,  8.,  5., 11., 14., 11.,  6.,
       11.,  5.,  6.,  8., 12.,  5.,  7., 13.,  6.,  6.,  7.,  5., 10.,
        5.,  7.,  8., 10.,  5.,  7.,  5., 12., 13., 10., 11.,  7.,  6.,
        9.,  6.,  9.,  9., 11., 10.,  6.,  9., 12.,  4., 12., 12.,  9.,
        4.,  6., 11.,  5.,  9., 13., 12.,  8., 11., 10.,  5.,  8., 10.,
        9., 10.,  9.,  8., 10.,  6.,  9., 12., 10., 12.,  6.,  9.,  5.,
        4.,  9., 12., 11., 14., 10.,  7., 12., 10.,  7., 12.,  8., 10.,
       12.,  9.,  7.,  8., 10., 10., 12.,  6., 10.,  9., 13.,  8.,  5.,
       11., 11.,  4., 11.,  5.,  9.,  5.,  8.,  7.,  7., 11.,  8., 10.,
       13.,  6.,  8.,  5.,  9., 13.,  7.,  9., 11.,  7.,  7.,  9., 13.,
        5., 11., 10., 12.,  6., 13.,  5., 11.,  9.,  5.,  8., 12., 10.,
        7.,  7., 10.,  6., 12.,  7.,  9.,  9.,  7., 12.,  8.,  6.,  9.,
       13.,  8., 13.,  9.,  6.,  8., 11.,  9., 10.,  4.,  4.,  7., 12.,
        8.,  6.,  4., 13., 12., 11., 11., 11., 12., 10., 13., 10.,  7.,
        6.,  8., 11.,  7.,  8.,  7., 12.,  9.,  7.,  7., 12., 11.,  7.,
       13.,  4., 11., 12.,  6., 13.,  4., 10.,  9.,  9., 12.,  5.,  9.,
       11.,  7.,  9.,  8., 10., 13.,  4.,  9.,  5.,  5., 10.,  6., 10.,
       11., 11.,  5.,  6.,  8.,  8.,  5.,  7.,  9.,  8.,  8., 11., 11.,
        6., 14.,  8.,  4.,  8.,  5.,  9., 13.,  8.,  9., 10.,  8., 13.,
       13.,  6.,  7.,  8.,  9., 13., 11., 12., 13.,  7.,  8.,  8.,  5.,
       13., 12.,  8., 13.,  7.,  8.,  5., 12.,  8.,  6.,  6., 13.,  8.,
        6., 12., 12., 12., 10.,  6.,  9., 11.,  6., 12.,  8., 10., 12.,
        4., 11.,  9.,  6., 13., 13.,  5., 12.,  7.,  7., 12., 14., 14.,
        7., 10.,  7., 12.,  8.,  6., 14.,  5.,  8.,  7., 13., 13.,  4.,
        7.,  6., 10., 12.,  5.,  9.,  7., 10., 10., 10., 10., 11., 11.,
       14.,  9.,  6., 13., 11., 10.,  7., 13., 12.,  6.,  9., 11.,  6.,
       11., 13., 11., 12.,  4.,  8., 12.,  7.,  5., 11.,  5.,  6.,  5.,
       13.,  7., 12., 12., 12., 12., 12.,  8., 13.,  8.,  4., 10.,  5.,
       10., 13.,  6.,  8., 11., 11.,  7.,  7.,  7., 11., 13., 12., 13.,
       10.,  8.,  7., 10.,  7., 10., 11., 11., 11.,  7., 11.,  9.,  6.,
       13., 13.,  7., 11.,  9., 11.,  6., 12.,  9.,  8., 12., 14., 10.,
        8., 13.,  9., 12., 12.,  8., 11.,  7.,  8.,  9.,  5.,  9.,  9.,
        8., 11., 12., 10.,  9., 10., 10.,  5.,  6., 13., 12., 11., 10.,
       13., 13.,  8., 12., 10., 10., 13.,  7.,  8., 13., 12.,  9., 10.,
        8., 12., 13., 10., 13.,  9.,  9.,  6.,  8.,  9.,  9.,  6.,  9.,
       11.,  8., 12.,  7.,  7.,  9.,  7., 13., 12., 12., 12., 11.,  6.,
       10.,  8.,  5.,  7.,  5.,  7., 13.,  9.,  8., 12.,  6., 11., 12.,
        6., 13., 12., 13., 11.,  8.,  5., 11.,  8., 10., 13., 11., 10.,
        5., 10.,  5., 12.,  7., 13.,  6.,  5., 11.,  9., 13., 13., 10.,
        9., 11., 10., 11.,  8., 14., 13.,  8.,  8., 11.,  6.,  6., 10.,
        7., 11., 11.,  5., 10., 10., 11., 11.,  7.,  5.,  6., 13.,  6.,
        8., 12., 10.,  6.,  7., 12.,  8.,  7.,  8., 13., 12., 11.,  9.,
       10.,  7.,  8.,  9., 13.,  8.,  8.,  7.,  5.,  7., 12., 11.,  6.,
       10., 10., 10., 11., 11.,  8.,  6., 13.,  7., 12.,  8., 10.,  5.,
        7.,  6., 13., 10., 12., 11., 12.,  5.,  9., 12.,  7.,  7.,  9.,
       13., 13.,  8.,  5., 10.,  6.,  6., 13.,  5.,  6., 13.,  5.,  7.,
        6.,  6.,  9., 10.,  6.,  5.,  4., 10., 12.,  8.,  7., 13., 12.,
       12.,  7., 13.,  9., 11., 11.,  7., 12., 10., 11.,  9., 10.,  6.,
        6.,  9.,  6.,  4.,  9.,  7.,  8.,  5., 10., 13., 10.,  9.,  5.,
        9., 11.,  6., 12.,  5., 12., 10.,  7.,  5., 13., 12.,  9., 12.,
        6., 11., 12.,  5.,  7.,  8., 12.,  5.,  9.,  9., 11., 12., 10.,
       12.,  7.,  8., 13.,  6.,  5.,  6., 10.,  5.,  4.,  7., 12., 11.,
        6.,  7.,  8.,  9.,  7., 11., 10.,  8., 10.,  7., 11.,  5., 12.,
       12., 10.,  8., 11., 11., 11.,  6.,  7.,  5., 11., 12.,  7.,  9.,
       11.,  7.,  8.,  5., 11.,  7.,  8.,  8., 11.,  8.,  5.,  6., 11.,
        7.,  4.,  9.,  8.,  9., 11., 12., 10.,  7.,  6.,  8.,  4.,  6.,
       12.,  5., 12., 11.,  7.,  7., 10.,  8.,  7.,  8., 13.,  9., 10.,
       11.,  4.,  9., 12.,  9.,  8.,  5.,  5., 12., 12.,  9.,  8.,  8.,
        5.,  9.,  5.,  6., 10.,  6.,  8., 12., 13., 12., 14.,  7.,  7.,
        7., 12.,  8., 13.,  4., 11., 11.,  9., 10.,  5.,  6.,  9.,  9.,
        4., 13.,  9.,  6.,  9., 11., 10., 11.,  9.,  9., 11., 10., 11.,
       10.,  8.,  6.,  9., 11.,  8.,  9.,  8., 12., 13., 12., 14.])
View Code

 

3. 可视化瑞士卷数据集

In [6]:
# 可视化瑞士卷数据集
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(x[:, 0], x[:, 1], x[:, 2],marker='o',c=y)
Out[6]:
<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x23d4b77ff08>
 

4. 不同自变量个数对应的重建误差

In [7]:
# 不同自变量个数对应的重建误差
from sklearn import manifold
def LLE_components(*data):
    X, Y = data
    for n in [3,2,1]:
        lle = manifold.LocallyLinearEmbedding(n_components=n) # LLE算法对应的类是LocallyLinearEmbedding
        lle.fit(X)
        print("n = %d 重建误差:"%n, lle.reconstruction_error_)
In [8]:
LLE_components(x, y)
n = 3 重建误差: 2.647672662461252e-10
n = 2 重建误差: 1.0562086520248735e-12
n = 1 重建误差: 1.8127904025897408e-17

5. 可视化LLE降维效果

In [9]:
# 显示重构后的数据
def LLE_neighbors(*data):
    X, Y = data
    Neighbors = [1, 2, 3, 4, 5, 15, 30, 100, Y.size-1] # 搜索样本的近邻的个数。值越大,则降维后样本的局部关系会保持的更好
    fig = plt.figure("LLE", figsize = (9, 9))
    for i, k in enumerate(Neighbors):
        lle = manifold.LocallyLinearEmbedding(n_components=2,n_neighbors=k,eigen_solver='dense') # 降维到二维,搜索样本紧邻个数,特征分解用dense
        X_r = lle.fit_transform(X) # X_r是降维后的数据
        ax = fig.add_subplot(3,3,i+1) # 3x3的fig图
        ax.scatter(X_r[:,0],X_r[:,1],marker='o',c=Y,alpha=0.5) # 画散点图
        ax.set_title("k = %d"%k) # 设置标题
        plt.xticks(fontsize=10, color="darkorange") # 设置y轴的字体大小及颜色
        plt.yticks(fontsize=10, color="darkorange") # 设置y轴的字体大小及颜色
    plt.suptitle("LLE") #总图标题
    plt.show()
In [10]:
LLE_neighbors(x, y)

 
结论:
3维的瑞士卷数据集被映射到二维

 

posted @ 2022-03-17 00:14  Theext  阅读(443)  评论(0)    收藏  举报