感知机学习算法的原始形式 | 机器学习随笔

根据《统计学习方法 第二版》中2.3.1部分,使用python编写了感知机学习算法的原始形式,具体代码如下:

import numpy as np

def perceptron_raw():
    data = np.array(
        [
            [3, 3, 1],
            [4, 3, 1],
            [1, 1, -1],
        ]
    )
    X_name = ['x'+str(x+1) for x in range(len(data))]
    X = data[:, :-1]
    Y = data[:, -1]
    
    w = np.matrix([0, 0])
    b = 0
    rate = 1.

    cons = [False]
    i = 0
    while not all(cons):
        cons = []
        for x_name, x, y in zip(X_name, X, Y):
            con = y*(w @ np.matrix(x).T+b)
            if con[0,0] <= 0:
                cons.append(False)
                w = w + rate*y*x
                b = b + rate*y
                i += 1
                print('迭代次数:{}, 误分类点:{}, w:{}, b:{}'.format(i, x_name, w.tolist()[0], b))
                break
            else:
                cons.append(True)
    print('迭代次数:{}, 误分类点:{}, w:{}, b:{}'.format(i+1, 0, w.tolist()[0], b))


if __name__ == '__main__':
    perceptron_raw()

    print('end the program.')

运行结果一致,最终输出如下:

迭代次数:1, 误分类点:x1, w:[3.0, 3.0], b:1.0
迭代次数:2, 误分类点:x3, w:[2.0, 2.0], b:0.0
迭代次数:3, 误分类点:x3, w:[1.0, 1.0], b:-1.0
迭代次数:4, 误分类点:x3, w:[0.0, 0.0], b:-2.0
迭代次数:5, 误分类点:x1, w:[3.0, 3.0], b:-1.0
迭代次数:6, 误分类点:x3, w:[2.0, 2.0], b:-2.0
迭代次数:7, 误分类点:x3, w:[1.0, 1.0], b:-3.0
迭代次数:8, 误分类点:0, w:[1.0, 1.0], b:-3.0
end the program.

posted on 2022-10-30 21:10  里斯斯里  阅读(51)  评论(0)    收藏  举报

导航