1 # encoding: utf-8
2 import numpy as np
3 import matplotlib.pyplot as plt
4 import cPickle
5 import gzip
6
class SVC(object):
    """Linear soft-margin SVM trained with a simplified SMO loop.

    Multi-class problems are handled one-vs-one: ``fit`` trains one binary
    classifier for every pair of classes and ``predict`` takes a majority
    vote.  The decision function is evaluated through an explicit weight
    vector (see ``gx``), so only the linear kernel is effectively supported.
    """

    def __init__(self, c=1.0, delta=0.001, tol=0.001):
        """Create an untrained classifier.

        Args:
            c: Soft-margin penalty; upper bound of every multiplier.
            delta: Minimum change of a multiplier for a pair update to count
                as progress.
            tol: Tolerance used when checking the KKT conditions.
        """
        self.N = 0                             # number of training samples
        self.delta = delta
        self.tol = tol
        self.X = None                          # training samples, (N, wn)
        self.y = None                          # labels in {-1, +1}, (N, 1)
        self.w = None                          # weight vector, (1, wn)
        self.wn = 0                            # number of features
        self.K = np.zeros((self.N, self.N))    # kernel (Gram) matrix
        self.a = np.zeros((self.N, 1))         # Lagrange multipliers
        self.E = np.zeros((self.N, 1))         # cached errors g(x_i) - y_i
        self.b = 0
        self.C = c
        self.stop = 1                          # 1 while an iteration made no update
        self.is_update = 0                     # 1 once a pair update succeeded
        self.k = 0                             # number of classes seen by fit
        self.cls = 0                           # sorted class labels seen by fit
        self.train_result = []                 # [i, j, (w, b)] per class pair

    @staticmethod
    def _scalar(value):
        """Return a number or one-element array as a plain Python float.

        Converting by value also detaches the result from NumPy views.
        """
        return float(np.ravel(value)[0])

    def _snap(self, alpha):
        """Clamp a multiplier to [0, C], snapping round-off residue to the bounds."""
        alpha = float(alpha)
        eps = 1e-8 * self.C
        if alpha <= eps:
            return 0.0
        if alpha >= self.C - eps:
            return float(self.C)
        return alpha

    def _errors(self):
        """Return the (N, 1) array of prediction errors g(x_i) - y_i."""
        return np.dot(self.X, self.w.T) + self.b - self.y

    def _sweep(self, indices):
        """Try to optimise every sample in ``indices`` that violates KKT."""
        for index in indices:
            if self.satisfy_kkt([self.a[index], index]) != 1:
                self.update(index)

    def kernel_function(self, x1, x2):
        """Linear kernel: the inner product of the two samples."""
        return np.dot(x1, x2)

    def kernel_matrix(self, x):
        """Fill the symmetric Gram matrix ``self.K`` for the samples ``x``.

        ``self.K`` must already be allocated with shape (len(x), len(x)).
        """
        n = len(x)
        for i in range(n):
            for j in range(i, n):
                self.K[i, j] = self.K[j, i] = self.kernel_function(x[i], x[j])

    def get_w(self):
        """Return the weight vector sum_i a_i * y_i * x_i with shape (1, wn)."""
        ay = self.a * self.y
        samples = np.asarray(self.X, dtype=float)
        return np.dot(ay.T, samples).reshape((1, self.wn))

    def get_b(self, a1, a2, a1_old, a2_old):
        """Return the bias after the multipliers at ``a1`` and ``a2`` changed.

        Args:
            a1, a2: Indices of the two updated multipliers.
            a1_old, a2_old: Their values before the update.

        ``self.E`` and ``self.b`` must still hold the pre-update values.
        """
        y1 = self._scalar(self.y[a1])
        y2 = self._scalar(self.y[a2])
        a1_new = self._scalar(self.a[a1])
        a2_new = self._scalar(self.a[a2])
        step1 = y1 * (a1_new - self._scalar(a1_old))
        step2 = y2 * (a2_new - self._scalar(a2_old))
        b_old = self._scalar(self.b)
        b1_new = (-self._scalar(self.E[a1]) - step1 * self.K[a1][a1]
                  - step2 * self.K[a2][a1] + b_old)
        b2_new = (-self._scalar(self.E[a2]) - step1 * self.K[a1][a2]
                  - step2 * self.K[a2][a2] + b_old)
        # A multiplier strictly inside (0, C) pins the bias exactly; only
        # when both sit on a bound is the midpoint used.
        if 0 < a1_new < self.C:
            return float(b1_new)
        if 0 < a2_new < self.C:
            return float(b2_new)
        return float((b1_new + b2_new) / 2.0)

    def gx(self, x):
        """Decision function g(x) = w . x + b."""
        return np.dot(self.w, x) + self.b

    def satisfy_kkt(self, a):
        """Return 1 if a sample satisfies the KKT conditions, else 0.

        Args:
            a: Pair ``[alpha, index]`` with the multiplier and the sample index.

        The margin is compared against 1 within ``self.tol`` because an exact
        floating-point equality almost never holds.
        """
        index = a[1]
        alpha = self._scalar(a[0])
        margin = self._scalar(self.y[index]) * self._scalar(self.gx(self.X[index]))
        if alpha <= 0:
            satisfied = margin >= 1 - self.tol
        elif alpha >= self.C:
            satisfied = margin <= 1 + self.tol
        else:
            satisfied = abs(margin - 1) <= self.tol
        return 1 if satisfied else 0

    def clip_func(self, a_new, a1_old, a2_old, y1, y2):
        """Clip the unconstrained second multiplier to its feasible segment.

        The segment [low, high] keeps both multipliers inside [0, C] while
        preserving the equality constraint of the dual problem.
        """
        if y1 == y2:
            low = max(0, a1_old + a2_old - self.C)
            high = min(self.C, a1_old + a2_old)
        else:
            low = max(0, a2_old - a1_old)
            high = min(self.C, self.C + a2_old - a1_old)
        if a_new < low:
            a_new = low
        if a_new > high:
            a_new = high
        return a_new

    def update_a(self, a1, a2):
        """Jointly optimise the multipliers at indices ``a1`` and ``a2``.

        Returns:
            1 if the multipliers were changed, 0 if the pair was skipped
            (degenerate curvature or a change smaller than ``self.delta``).
        """
        eta = self.K[a1][a1] + self.K[a2][a2] - 2 * self.K[a1][a2]
        if eta <= 1e-9:
            # Identical (or numerically identical) samples: dividing by eta
            # would produce inf/nan multipliers, so the pair is skipped.
            return 0
        alpha1 = self._scalar(self.a[a1])
        alpha2 = self._scalar(self.a[a2])
        y1 = self._scalar(self.y[a1])
        y2 = self._scalar(self.y[a2])
        e1 = self._scalar(self.E[a1])
        e2 = self._scalar(self.E[a2])
        a2_unclipped = alpha2 + y2 * (e1 - e2) / eta
        a2_new = self._snap(self.clip_func(a2_unclipped, alpha1, alpha2, y1, y2))
        a1_new = self._snap(alpha1 + y1 * y2 * (alpha2 - a2_new))
        if abs(a1_new - alpha1) < self.delta:
            return 0
        self.a[a1] = a1_new
        self.a[a2] = a2_new
        self.is_update = 1
        return 1

    def update(self, first_a):
        """Optimise ``first_a`` against the other samples in index order.

        After every successful pair update the bias, weights and error cache
        are refreshed.

        NOTE(review): the search stops at the first partner that yields no
        progress, as in the original code; this can leave a violating sample
        unoptimised even though a later partner would have worked.
        """
        for second_a in range(self.N):
            if second_a == first_a:
                continue
            # Copy by value: indexing self.a yields views that update_a
            # overwrites in place, which would make "old" equal "new".
            a1_old = self._scalar(self.a[first_a])
            a2_old = self._scalar(self.a[second_a])
            if self.update_a(first_a, second_a) == 0:
                return
            self.b = self.get_b(first_a, second_a, a1_old, a2_old)
            self.w = self.get_w()
            self.E = self._errors()
            self.stop = 0

    def train(self, x, y, max_iternum=100):
        """Train one binary classifier with SMO.

        Args:
            x: Sequence of N samples with the same number of features.
            y: Sequence of N labels in {-1, +1}.
            max_iternum: Maximum number of outer iterations.

        Returns:
            Tuple ``(w, b)`` with the weight vector of shape (1, wn) and bias.
        """
        self.X = np.asarray(x, dtype=float)
        self.N = len(self.X)
        self.wn = self.X.shape[1]
        self.y = np.asarray(y, dtype=float).reshape((self.N, 1))
        self.K = np.zeros((self.N, self.N))
        self.kernel_matrix(self.X)
        self.b = 0.0
        self.a = np.zeros((self.N, 1))
        self.w = self.get_w()
        self.E = self._errors()
        for _ in range(max_iternum):
            self.stop = 1
            # Reset per iteration so the full sweep below runs whenever the
            # non-bound sweep made no progress.
            self.is_update = 0
            non_bound = [i for i in range(self.N) if 0 < self.a[i, 0] < self.C]
            if non_bound:
                self._sweep(non_bound)
            if not non_bound or self.is_update == 0:
                self._sweep(range(self.N))
            if self.stop:
                break
        return self.w, self.b

    def fit(self, x, y):
        """Train k*(k-1)/2 one-vs-one classifiers for the k classes in ``y``.

        Any classifiers from an earlier ``fit`` call are discarded.
        """
        self.cls, inverse = np.unique(np.ravel(y), return_inverse=True)
        inverse = np.ravel(inverse)
        self.k = len(self.cls)
        self.train_result = []
        for i in range(self.k):
            for j in range(i):
                sub_x, sub_y = self.sub_data(x, inverse, i, j)
                self.train_result.append([i, j, self.train(sub_x, sub_y)])

    def predict(self, x_new):
        """Return the class label that wins the one-vs-one vote for ``x_new``.

        As a side effect ``self.w`` and ``self.b`` are left set to the last
        pairwise classifier that was evaluated.

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        if self.k == 0:
            raise ValueError("fit must be called before predict")
        votes = np.zeros(self.k)
        for i, j, model in self.train_result:
            self.w = model[0]
            self.b = model[1]
            # sub_data labels class j as +1 and class i as -1.
            if self.classfy(x_new) == 1:
                votes[j] += 1
            else:
                votes[i] += 1
        return self.cls[np.argmax(votes)]

    def sub_data(self, x, y, i, j):
        """Select the samples of classes ``i`` and ``j`` for a binary problem.

        Returns:
            Tuple ``(subx, suby)`` where class ``i`` is labelled -1 and class
            ``j`` is labelled +1; samples of other classes are dropped.
        """
        subx = []
        suby = []
        for sample, label in zip(x, y):
            if label == i:
                subx.append(sample)
                suby.append(-1)
            elif label == j:
                subx.append(sample)
                suby.append(1)
        return subx, suby

    def classfy(self, x_new):
        """Return -1 or +1 for ``x_new`` using the current ``w`` and ``b``.

        A score of exactly zero is assigned to the +1 side.
        """
        score = self._scalar(self.gx(x_new))
        return -1 if score < 0 else 1
172
173
def load_data(path='../data/mnist.pkl.gz'):
    """Load the gzip-compressed, pickled MNIST splits.

    Args:
        path: Location of the ``mnist.pkl.gz`` archive.

    Returns:
        Tuple ``(training_data, validation_data, test_data)`` exactly as
        stored in the pickle.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the archive even if unpickling fails.
    with gzip.open(path, 'rb') as f:
        training_data, validation_data, test_data = cPickle.load(f)
    return (training_data, validation_data, test_data)
179
if __name__ == "__main__":
    # Train on the first `sample_count` MNIST training digits and report the
    # accuracy on the first `sample_count` test digits.
    classifier = SVC()
    np.random.seed(0)
    sample_count = 1000
    training_data, validation_data, test_data = load_data()
    classifier.fit(training_data[0][:sample_count],
                   training_data[1][:sample_count])
    test_inputs = test_data[0][:sample_count]
    test_labels = test_data[1][:sample_count]
    predictions = [classifier.predict(sample) for sample in test_inputs]
    num_correct = sum(int(guess == label)
                      for guess, label in zip(predictions, test_labels))
    # Accuracy figures noted by the original author:
    # 72/100, 808/1000, 8194/10000 (slow).
    print("%s of %s values correct." % (num_correct, len(test_labels)))