统计学习-感知机
感知机对应于输入空间中实例划分为正负两类的分离超平面, 属于判别模型. 其目标是求出将训练数据进行线性划分的分离超平面. 是神经网络与去过向量机的基础
定义: 假设输入空间(特征空间)是\(\mathcal{X} \subseteq \mathbf{R}^n\), 输出空间是\(\mathcal{Y} = \{+1, -1\}\), 输入\(x \in \mathcal{X}\)是一个输入实例的特征向量, 对应于输入空间的点, 实例输出\(y \in \mathcal{Y}\)表示实例的类别, 由输入空间到输出空间的函数\begin{align} f(x) = \text{sign} (w \cdot x + b) \end{align}称为感知机. 其中\(w, b\)为感知机的模型参数, \(w \in \mathbf{R}^n\)叫作可权值或者权值向量, \(b \in \mathcal{R}\)叫作偏置.
感知机是一种线性分类模型, 属于判别模型. 其假设空间是定义在特征空间中的所有线性分类模型或者线性分类器\(\{f|f(x) = w \cdot x + b\}\). 方程 \begin{align} w \cdot x + b = 0 \end{align}对应于特征空间\(\mathcal{R}^n\)中的一个超平面\(S\), 其中\(w\)是超平面的法向量, \(b\)是超平面的截距. 这个超平面将特征空间划分为两个部分 位于两分的点分别被分为正负两类, 因此被称为分离超平面.
感知机的损失函数采用的是误分类点\(x_i \in M\)到超平面\(S\)的总距离: \begin{align} L(w, b) = - \frac{1}{|w|} \sum_{x_i \in M} y_i (w \cdot x_i + b) \end{align}, 最终目标转化求解\(L(w, b)\)的最小值: \begin{align} \min_{w,b} L(w, b) = - \sum_{x_i \in M} y_i(w_i \cdot x_i + b) \end{align}. 首先选取一个超平面\((w_0, b_0)\), 采用梯度下降法不断极小化目标函数, 极小化过程中不是一次使用\(M\)中氖误分类点的梯度下降, 而是一次随机选取一个误分类点使其梯度下降. 根据损失函数\(L(w, b)\)的梯度 \begin{align} \nabla L(w, b) = (-\sum_{x_i \in M} y_i x_i, -\sum_{x_i \in M} y_i) \end{align}, 采用梯度下降法, 不断对\(w, b\)进行更新: \begin{align} w &\gets w + \eta y_i x_i \\ b &\gets b + \eta y_i \end{align}. 其中\(\eta \in (0, 1]\)是步长, 也称学习率, 直到\(L(w, b)\)减到最小.
梯度下降法: 对于目标函数\(f(x_1, x_2, \cdots, x_n)\), 其梯度为\begin{align} \nabla f = (\frac{\partial f}{\partial x_1}, \frac{\partial f}{\partial x_2}, \cdots, \frac{\partial f}{\partial x_n})^T \end{align}. 在梯度反方向增加\(\mathbf{x}\), 这个时候函数的值下降梯度最大.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# @author acrazing - joking.young@gmail.com
# @version 1.0.0
# @since 2017-05-23 22:30:01
#
# c2_perceptron.py
#
import random
import numpy
from study.utils.utils import err, debug
def train(x, y, rate=1, max_loop=10):
"""
:type x: numpy.ndarray
:type y: numpy.ndarray
:param rate:
:param max_loop:
:return:
"""
rows, dims = x.shape
w = numpy.zeros(shape=dims, dtype='float')
b = 0.0
r = range(rows)
max_times = max_loop * rows
while max_times > 0:
count = 0
max_times -= 1
for i in r:
value = (numpy.inner(w, x[i]) + b) * y[i]
if value > 0:
continue
err('error row: %s, y: %s, value: %s, x: %s, w: %s, b: %s' % (i, y[i], value, x[i], w, b))
count += 1
while True:
w = w + rate * y[i] * x[i]
b = b + rate * y[i]
value = (numpy.inner(w, x[i]) + b) * y[i]
if value > 0:
break
if count == 0:
break
if max_times == 0:
raise RuntimeError()
return w, b
def predict(x, w, b):
"""
:type x: numpy.ndarray
:type w: numpy.ndarray
:type b: float
:return:
"""
return -1 if numpy.inner(w, x) + b < 0 else 1
def test_perceptron(rows=5, dims=2, rate=1, max_loop=10):
x = numpy.random.randint(0, 5, (rows, dims)) * 1.0
""":type x: numpy.ndarray"""
w0 = numpy.random.randint(1, 5, dims) * 1.0
b0 = random.randint(-15, -10) * 1.0
y = numpy.array([-1 if numpy.inner(w0, r) + b0 < 0 else 1 for r in x])
err('expected w: %s, b: %s' % (w0, b0))
err('train x:\n%s\ntrain y:\n%s' % (x, y))
w, b = train(x, y, rate, max_loop)
err('trained w: %s, b: %s' % (w, b))
debug(globals(), __name__)
算法的收敛性(Novikoff): 假设输入数据\(T_N\)是线性可分的, 则
- 存在一个分离超平面\(\hat{w}_{opt} = (w_{opt}, b_{opt})\)及常数\(\gamma > 0\), 使得对任意\(i = 1, 2, \cdots, N\), 使得 \begin{align} y_i(\hat{w}_{opt} \cdot \hat{x}_i) = y_i (w_{opt} \cdot x_{i} + b_{opt}) \geqslant \gamma \end
- 令\(R = \max_{1 \leqslant i \leqslant N} \| \hat{x}_i \|\), 则算法迭代次数\(k\)满足 \begin{align} k \leqslant \left( \frac{R}{\gamma} \right)^2 \end
如果数据集线性不可分, 则感知机学习算法不收敛, 迭代结果会发生震荡.
对偶形式
在原始形式的迭代过程中, 通过误分类点对\(w, b\)进行更新, 假定对第\(i\)个点\((x_i, y_i)\)更新了\(n_i\)次, 则最终\(w, b\)因为这个点的增量为\(\delta w_i = \alpha_i y_i x_i, \delta b_i = \alpha_i y_i, \alpha_i = n_i \eta\). 如果假定初始的\((w, b)\)为\((0, 0)\), 则最终的\(w, b\)分别为 \begin{align} w &= \sum_{i = 1}^N \alpha_i y_i x_i \\ b &= \sum_{i=1}^N \alpha_i y_i \end{align}. 这意味着, 误分类的点\((x_i, y_i)\)应该满足 \begin{align} y_i(w \cdot x_i + b) = y_i \left(\sum_{j=1}^N \alpha_j y_j x_j \cdot x_i + b\right) \leqslant 0 \end{align}, 在迭代过程中, 只需要更新两个标量\(\alpha_i, b\)即可: \begin{align} \alpha_i &\gets \alpha_i + \eta \\ b &\gets b + \eta y_i \end{align}. 而\(x_j \cdot x_i\)可以进行缓存, 得到一个\(N \times N\)的矩阵, 也就是Gram矩阵.
似乎并没有简化计算, 虽然在更新这一步将复杂度由\(O(n)\)降到了\(O(1)\), 但是在检查是否是误分类点时复杂度由\(O(n)\)增加至了\(O(N)\), 如果\(N \gg n\), 这是不合理的, 同时还额外占用了\(O(N^2)\)的空间.

浙公网安备 33010602011771号