肉身布莱特-感知机python实现

一、代码:做线性拟合

import numpy as np
from matplotlib import pyplot as plt
# 很简单的一个体积增大和毒气量的关系拟合
virulence = np.random.rand(100)
volume = np.random.rand(100)
virulence.sort()
volume.sort()
 
plt.scatter(volume, virulence)
 
w = 0
b = 0
alpha = 0.05
num_iterations = 1000
 
for i in range(num_iterations):
    for j in range(len(virulence)):
        x = volume[j]
        y = virulence[j]
        prediction = w * x + b
        loss = (y - prediction) ** 2 # MSE
 
        w -= alpha *-2 * x * (y - prediction) # 负梯度移动才会找到最小值
        b -= alpha * -2 * (y - prediction)
print(loss)
plt.plot(volume, volume * w + b)
plt.show()
 

二、代码:二元分类。当然也可用svm来解决

import numpy as np
 
class Perceptron:
    def __init__(self):
        self.w = None
        self.b = None
        
    def train(self, X, y, lr=0.1, epochs=10):
        """
        训练感知机模型
        
        参数:
        X -- 输入数据,一个形状为 (m, n) 的二维数组,每行表示一个样本
        y -- 标签数据,一个形状为 (m,) 的一维数组,每个元素为 1 或 -1
        lr -- 学习率,用于控制每次迭代的步长,默认为 0.1
        epochs -- 迭代次数,默认为 10
        
        返回值:
        无
        """
        m, n = X.shape
        self.w = np.zeros(n)
        self.b = 0
        
        for epoch in range(epochs):
            for i in range(m):
                xi = X[i]
                yi = y[i]
                if yi * (np.dot(xi, self.w) + self.b) <= 0:
                    self.w += lr * yi * xi
                    self.b += lr * yi
                    
    def predict(self, X):
        """
        预测数据的标签
        
        参数:
        X -- 输入数据,一个形状为 (m, n) 的二维数组,每行表示一个样本
        
        返回值:
        预测结果,一个形状为 (m,) 的一维数组,每个元素为 1 或 -1
        """
        y_pred = np.dot(X, self.w) + self.b
        y_pred = np.where(y_pred > 0, 1, -1)
        return y_pred
 
import matplotlib.pyplot as plt
import numpy as np
from test import Perceptron
 
# 创建二分类数据
X = np.array([[3, 3], [4, 3], [1, 1]])
y = np.array([1, 1, -1])
 
# 训练模型
perceptron = Perceptron()
perceptron.train(X, y)
 
# 预测结果
y_pred = perceptron.predict(X)
 
# 绘制数据和分类结果
plt.scatter(X[:, 0], X[:, 1], c=y_pred)
plt.show()
posted @ 2023-05-08 17:33  cccjjh  阅读(26)  评论(0)    收藏  举报