数学之路(3)-机器学习(3)-机器学习算法-SVM[7]
本博客所有内容是原创,未经书面许可,严禁任何形式的转载
http://blog.csdn.net/u010255642
根据SMO算法的描述,用Python实现,部分代码如下。代码定义了一个svm_pmcp类,所有运算都在该类中完成,便于封装和实际应用。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#麦好:myhaspl@qq.com
#http://blog.csdn.net/u010255642
#svm算法
import math

import matplotlib.pyplot as plt
import numpy as np
# Linear (inner-product) kernel
def arraydot(x, y):
    '''Return the inner product <x, y> as a scalar (linear kernel).

    x and y may be lists, 1-D arrays, or row/column vectors (including
    np.matrix). Both are flattened first. The previous ``x.T * y`` gave an
    elementwise product for 1-D ndarrays and a broadcast n x n result for
    2-D column vectors, not an inner product.
    '''
    return np.dot(np.asarray(x).ravel(), np.asarray(y).ravel())
# SVM parameter and computation class
class svm_pmcp:
    '''SVM training state plus an SMO (sequential minimal optimisation) solver.

    Follows the structure of Platt's SMO pseudocode. The decision value of
    training sample i is u_i = sum_j alpha_j * y_j * K(x_j, x_i) - b, and its
    error is E_i = u_i - y_i. Errors are recomputed on demand; there is no
    error cache.

    Typical use: samples_init, kernel_init, tol_init, eps_init, c_init,
    then mainroutine. Labels are expected to be +1 / -1.
    '''

    def __init__(self):
        '''Create an empty model.'''
        self.alpha = []       # Lagrange multiplier per training sample
        self.samples = []     # training inputs
        self.labels = []      # training targets
        self.boundalpha = []  # indices i with 0 < alpha[i] < C (non-bound)
        self.b = 0            # threshold of the decision function
        self.w = None         # weight vector, filled in by update_w

    def samples_init(self, samples):
        '''Load (input, label) pairs and reset the multipliers and threshold.

        Every multiplier is reset to 0, so alpha always has exactly one
        entry per stored sample, even if this is called more than once.
        '''
        for (mysp, mylb) in samples:
            self.samples.append(mysp)
            self.labels.append(mylb)
        # SMO starts from alpha = 0, which satisfies sum(alpha_i * y_i) = 0
        self.alpha = [0] * len(self.samples)
        self.boundalpha = []
        self.b = 0

    def kernel_init(self, func):
        '''Set the kernel function K(x, y).'''
        self.kernel_func = func

    def lagrange_multiplier(self, i):
        '''Return the Lagrange multiplier of sample i.'''
        return self.alpha[i]

    def svmoutput(self, i):
        '''Return the decision value u_i for training sample i.'''
        xi = self.samples[i]
        total = 0.0
        for aj, yj, xj in zip(self.alpha, self.labels, self.samples):
            if aj != 0:  # samples with alpha == 0 contribute nothing
                total += aj * yj * self.kernel_func(xj, xi)
        return total - self.b

    def tol_init(self, mytol):
        '''Set tol, the tolerance used when checking the KKT conditions.'''
        self.tol = mytol

    def eps_init(self, myeps):
        '''Set eps, the smallest (relative) multiplier change accepted.'''
        self.eps = myeps

    # original (misspelled) setter name, kept for backward compatibility
    eps_inbit = eps_init

    def c_init(self, myc):
        '''Set C, the upper bound of every multiplier.'''
        self.c = myc

    def choicesecond_max(self, nte):
        '''Return the non-bound index k maximising |E_k - nte| (None if none).

        nte is the error E2 of the example currently being examined.
        '''
        best = None
        best_gap = -1.0
        for k in self.boundalpha:
            gap = abs(self.svmoutput(k) - self.labels[k] - nte)
            if gap > best_gap:
                best, best_gap = k, gap
        return best

    def choicesecond_random(self, exclude=None):
        '''Return a random sample index other than exclude.

        If no other index exists, exclude itself is returned.
        '''
        choices = [k for k in range(len(self.samples)) if k != exclude]
        if not choices:
            return exclude
        return choices[np.random.randint(len(choices))]

    def get_lh(self, i, j):
        '''Return the bounds (L, H) for the new alpha[i] when paired with j.

        The bounds combine the box 0 <= alpha <= C with the equality
        constraint that keeps y_i*alpha_i + y_j*alpha_j fixed.
        '''
        ai, aj = self.alpha[i], self.alpha[j]
        if self.labels[i] != self.labels[j]:
            return max(0, ai - aj), min(self.c, self.c + ai - aj)
        return max(0, ai + aj - self.c), min(self.c, ai + aj)

    def update_b(self, i1=None, a1=None, i2=None, a2=None, e1=None, e2=None):
        '''Update the threshold b after alpha[i1], alpha[i2] become a1, a2.

        Must be called before store_alpha, because the old multipliers are
        read from self.alpha. e1 and e2 are the errors computed with the old
        multipliers. Called without a step, b is left unchanged. Returns b.
        '''
        if i1 is None or i2 is None:
            return self.b
        x1, x2 = self.samples[i1], self.samples[i2]
        d1 = self.labels[i1] * (a1 - self.alpha[i1])
        d2 = self.labels[i2] * (a2 - self.alpha[i2])
        k11 = self.kernel_func(x1, x1)
        k12 = self.kernel_func(x1, x2)
        k22 = self.kernel_func(x2, x2)
        b1 = e1 + d1 * k11 + d2 * k12 + self.b
        b2 = e2 + d1 * k12 + d2 * k22 + self.b
        if 0 < a1 < self.c:
            self.b = b1
        elif 0 < a2 < self.c:
            self.b = b2
        else:
            # both multipliers at a bound: use the midpoint of b1 and b2
            self.b = (b1 + b2) / 2.0
        return self.b

    def update_w(self):
        '''Recompute w = sum_i alpha_i * y_i * x_i and return it.

        Only meaningful for a linear kernel. Returns None when there are no
        samples.
        '''
        if not self.samples:
            self.w = None
            return None
        w = np.zeros(np.shape(self.samples[0]))
        for a, y, x in zip(self.alpha, self.labels, self.samples):
            if a != 0:
                w = w + a * y * np.asarray(x, dtype=float)
        self.w = w
        return w

    def alpha_nozero_noc(self):
        '''Return the indices whose multiplier is neither 0 nor C.'''
        return [i for i, a in enumerate(self.alpha) if 0 < a < self.c]

    def store_alpha(self, i1, a1, i2, a2):
        '''Store the new multipliers and refresh the non-bound index list.'''
        self.alpha[i1] = a1
        self.alpha[i2] = a2
        self.boundalpha = self.alpha_nozero_noc()

    def _pair_objective(self, a1, a2, f1, f2, s, k11, k12, k22):
        '''Dual objective (to maximise) restricted to one pair, up to a constant.'''
        psi = (a1 * f1 + a2 * f2 + 0.5 * a1 * a1 * k11
               + 0.5 * a2 * a2 * k22 + s * a1 * a2 * k12)
        return -psi

    def takestep(self, i1, i2, e2, alpha2):
        '''Jointly optimise alpha[i1] and alpha[i2]; return True if they changed.

        e2 and alpha2 are the error and multiplier of sample i2, as computed
        by the caller.
        '''
        if i1 == i2:
            return False
        alpha1 = self.lagrange_multiplier(i1)
        y1 = self.labels[i1]
        y2 = self.labels[i2]
        e1 = self.svmoutput(i1) - y1
        s = y1 * y2
        l, h = self.get_lh(i2, i1)
        if l == h:
            return False
        x1, x2 = self.samples[i1], self.samples[i2]
        k11 = self.kernel_func(x1, x1)
        k12 = self.kernel_func(x1, x2)
        k22 = self.kernel_func(x2, x2)
        # second derivative of the objective along the constraint line
        eta = float(2 * k12 - k11 - k22)
        if eta < 0:
            a2 = alpha2 - y2 * (e1 - e2) / eta
            if a2 < l:
                a2 = l
            elif a2 > h:
                a2 = h
        else:
            # degenerate case: evaluate the objective at both segment ends
            f1 = y1 * (e1 + self.b) - alpha1 * k11 - s * alpha2 * k12
            f2 = y2 * (e2 + self.b) - s * alpha1 * k12 - alpha2 * k22
            lobj = self._pair_objective(alpha1 + s * (alpha2 - l), l,
                                        f1, f2, s, k11, k12, k22)
            hobj = self._pair_objective(alpha1 + s * (alpha2 - h), h,
                                        f1, f2, s, k11, k12, k22)
            if lobj > hobj + self.eps:
                a2 = l
            elif lobj < hobj - self.eps:
                a2 = h
            else:
                a2 = alpha2
        if abs(a2 - alpha2) < self.eps * (a2 + alpha2 + self.eps):
            return False
        a1 = alpha1 + s * (alpha2 - a2)
        # b must be updated while self.alpha still holds the old values,
        # and w only after the new values are stored
        self.update_b(i1, a1, i2, a2, e1, e2)
        self.store_alpha(i1, a1, i2, a2)
        self.update_w()
        return True

    def examineexample(self, myi):
        '''Try to take an SMO step for sample myi; return 1 if one was taken.'''
        y2 = self.labels[myi]
        alpha2 = self.lagrange_multiplier(myi)
        e2 = self.svmoutput(myi) - y2
        r2 = e2 * y2
        # only samples violating the KKT conditions (beyond tol) are optimised
        if not ((r2 < -self.tol and alpha2 < self.c) or
                (r2 > self.tol and alpha2 > 0)):
            return 0
        # 1) heuristic: partner maximising |E1 - E2| among non-bound samples
        if len(self.boundalpha) > 1:
            myj = self.choicesecond_max(e2)
            if myj is not None and self.takestep(myj, myi, e2, alpha2):
                return 1
        # 2) non-bound samples, then 3) all samples, from a random start
        n = len(self.samples)
        start = self.choicesecond_random(myi)
        order = [(start + k) % n for k in range(n)]
        nonbound = set(self.boundalpha)
        for myj in order:
            if myj in nonbound and self.takestep(myj, myi, e2, alpha2):
                return 1
        for myj in order:
            if self.takestep(myj, myi, e2, alpha2):
                return 1
        return 0

    def loop1(self, nc):
        '''Examine every sample; return nc plus the number of steps taken.'''
        for i in range(len(self.samples)):
            nc += self.examineexample(i)
        return nc

    def loop2(self, nc):
        '''Examine the non-bound samples; return nc plus the steps taken.'''
        for i in self.alpha_nozero_noc():
            nc += self.examineexample(i)
        return nc

    def mainroutine(self):
        '''Run SMO until a full pass over all samples changes nothing.

        As in Platt's outer loop, passes over the non-bound samples repeat
        until they stop changing anything, then a full pass follows.
        '''
        numchanged = 0
        examineall = True
        while numchanged > 0 or examineall:
            if examineall:
                numchanged = self.loop1(0)
            else:
                numchanged = self.loop2(0)
            if examineall:
                examineall = False
            elif numchanged == 0:
                examineall = True
def mainsvm(mysamples, kernel=None, tol=0.001, eps=0.00001, c=1):
    '''Train an svm_pmcp with SMO on (input, label) pairs and return it.

    mysamples -- iterable of (input, label) pairs; labels expected +1 / -1
    kernel    -- kernel function K(x, y); defaults to arraydot (linear)
    tol       -- KKT tolerance passed to tol_init
    eps       -- minimum multiplier change passed to eps_init
    c         -- box constraint C passed to c_init
    The defaults are the values that used to be hard-coded here.
    '''
    mysvm = svm_pmcp()
    mysvm.samples_init(mysamples)
    mysvm.kernel_init(arraydot if kernel is None else kernel)
    mysvm.tol_init(tol)
    mysvm.eps_init(eps)
    mysvm.c_init(c)
    mysvm.mainroutine()
    # return the trained model; previously it was discarded
    return mysvm
后面关于SVM的章节将提供该类的下载地址及调用代码。


浙公网安备 33010602011771号