The Sigmoid Function
import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):
    y = 1.0 / (1.0 + np.exp(-x))
    return y

plot_x = np.linspace(-10, 10, 100)
plot_y = sigmoid(plot_x)
plt.plot(plot_x, plot_y)
plt.title('Sigmoid')
plt.show()
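For reference, the function plotted above and its derivative, the identity the logistic-regression gradient later in this post relies on, are

$$\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr).$$

Since $\sigma$ maps every real input into $(0, 1)$, it is a natural way to turn a linear score into a probability.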
Gradient Descent and Learning-Rate Analysis
import numpy as np
import matplotlib.pyplot as plt

def J(theta):  # loss function
    try:
        return (theta - 2.5) ** 2 - 1
    except OverflowError:
        return float('inf')

def dJ(theta):  # derivative of the loss function
    return 2 * (theta - 2.5)

def iteration(eta):
    theta = 0.0  # starting point
    theta_history = [theta]
    epsilon = 1e-8
    i_iter = 0
    n_iters = 450
    while i_iter < n_iters:
        gradient = dJ(theta)
        last_theta = theta
        theta = theta - eta * gradient
        i_iter += 1
        theta_history.append(theta)
        if abs(J(theta) - J(last_theta)) < epsilon:
            break  # stop once successive loss values are essentially equal
    return theta_history

if __name__ == '__main__':
    plot_x = np.linspace(-1, 6, 141)
    etas = [0.01, 0.8, 1.1]
    titles = ['0.01', '0.8', '1.1']
    for eta, title in zip(etas, titles):
        theta_history = iteration(eta)
        plt.plot(plot_x, J(plot_x), color='r')
        plt.plot(np.array(theta_history), J(np.array(theta_history)), color='b', marker='x')
        # label the figure
        plt.title(title)
        plt.xlabel('theta', fontsize=15)
        plt.ylabel('loss function', fontsize=15)
        plt.savefig('{}.png'.format(title))
        plt.clf()
        print('When eta={}, total steps of gradient descent is {}'.format(eta, len(theta_history)))
        # plt.show()
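The three learning rates behave very differently, and for this quadratic loss the reason can be worked out exactly. Each step is

$$\theta_{t+1} = \theta_t - \eta\, J'(\theta_t) = \theta_t - 2\eta\,(\theta_t - 2.5),$$

so the distance to the minimizer $\theta^\* = 2.5$ shrinks (or grows) by a constant factor per step:

$$\theta_{t+1} - 2.5 = (1 - 2\eta)\,(\theta_t - 2.5).$$

Convergence requires $|1 - 2\eta| < 1$, i.e. $0 < \eta < 1$. So $\eta = 0.01$ converges slowly (factor $0.98$), $\eta = 0.8$ converges quickly while oscillating around the minimum (factor $-0.6$), and $\eta = 1.1$ diverges (factor $-1.2$), which matches the step counts the script prints.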
Implementing Logistic Regression
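The class below fits the parameter vector $\theta$ by gradient descent on the cross-entropy loss. With $h_\theta(x) = \sigma(\theta^{T}x)$, the loss and its gradient (implemented by `J` and `dJ` inside `fit`) are

$$J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\Bigl[y^{(i)}\log h_\theta(x^{(i)}) + \bigl(1 - y^{(i)}\bigr)\log\bigl(1 - h_\theta(x^{(i)})\bigr)\Bigr],$$

$$\frac{\partial J}{\partial \theta_j} = \frac{1}{m}\sum_{i=1}^{m}\bigl(h_\theta(x^{(i)}) - y^{(i)}\bigr)\,x_j^{(i)},$$

where a column of ones is prepended to the inputs so that $\theta_0$ plays the role of the intercept.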
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

'''Logistic regression model'''
class LogisticRegression:

    def __init__(self):
        self.coef_ = None        # feature weights
        self.intercept_ = None   # intercept
        self._theta = None       # intercept + weights

    # sigmoid function
    def _sigmoid(self, x):
        y = 1.0 / (1.0 + np.exp(-x))
        return y

    '''
    X_train: training inputs x
    y_train: labels y corresponding to x
    eta: learning rate
    n_iters: maximum number of iterations
    '''
    def fit(self, X_train, y_train, eta=0.1, n_iters=1e4):
        assert X_train.shape[0] == y_train.shape[0], \
            'the training set and the labels must have the same length'

        # loss function (cross-entropy)
        def J(theta, X_b, y):
            p_predict = self._sigmoid(X_b.dot(theta))
            try:
                return - np.sum(y * np.log(p_predict) + (1 - y) * np.log(1 - p_predict)) / len(y)
            except Exception:
                return float('inf')

        # gradient of the loss function
        def dJ(theta, X_b, y):
            x = self._sigmoid(X_b.dot(theta))
            return X_b.T.dot(x - y) / len(X_b)

        # gradient descent
        def gradient_descent(X_b, y, initial_theta, eta, n_iters=1e4, epsilon=1e-8):
            theta = initial_theta
            i_iter = 0
            while i_iter < n_iters:
                gradient = dJ(theta, X_b, y)
                last_theta = theta
                theta = theta - eta * gradient
                i_iter += 1
                if abs(J(theta, X_b, y) - J(last_theta, X_b, y)) < epsilon:
                    break
            return theta

        X_b = np.hstack([np.ones((X_train.shape[0], 1)), X_train])
        initial_theta = np.zeros(X_b.shape[1])
        self._theta = gradient_descent(X_b, y_train, initial_theta, eta, n_iters)
        self.intercept_ = self._theta[0]
        self.coef_ = self._theta[1:]
        return self

    # predicted probabilities
    def predict_proba(self, X_predict):
        X_b = np.hstack([np.ones((X_predict.shape[0], 1)), X_predict])
        return self._sigmoid(X_b.dot(self._theta))

    # predicted classes
    def predict(self, X_predict):
        proba = self.predict_proba(X_predict)
        return np.array(proba > 0.5, dtype='int')
# scatter plot of the data set
def scatter_data():
    Data = pd.read_csv('ex2data1.txt', sep=',', header=None, names=['test 1', 'test 2', 'Admitted'])
    positive = Data[Data['Admitted'] == 1]
    negative = Data[Data['Admitted'] == 0]
    plt.scatter(positive['test 1'], positive['test 2'], s=30,
                c='b', marker='o', label='Admitted')
    plt.scatter(negative['test 1'], negative['test 2'], s=30,
                c='r', marker='x', label='Not Admitted')
    plt.xlabel('test 1 Score')
    plt.ylabel('test 2 Score')
    # plt.show()
    return Data
if __name__ == '__main__':
    LR = LogisticRegression()
    Data = scatter_data()
    cols = Data.shape[1]
    # fit on the two score columns; the labels are passed as a 1-D array
    # note: the scores are fed in unscaled here; with the default eta this can be
    # unstable, see the usage sketch below
    LR.fit(Data.values[:, 0:cols-1], Data.values[:, cols-1])
    print(LR.predict(Data.values[:, 0:cols-1]))
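A minimal usage sketch (assuming the LogisticRegression class above and ex2data1.txt in the working directory): standardizing the two score columns keeps the sigmoid inputs small, which makes gradient descent with the default eta much more stable, and training accuracy can then be checked directly. The standardization step is an addition of this sketch, not part of the original pipeline.

import numpy as np
import pandas as pd

data = pd.read_csv('ex2data1.txt', sep=',', header=None,
                   names=['test 1', 'test 2', 'Admitted'])
X = data[['test 1', 'test 2']].values
y = data['Admitted'].values

# zero mean, unit variance per column (keeps X_b.dot(theta) in a moderate range)
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)

clf = LogisticRegression()
clf.fit(X_scaled, y, eta=0.1, n_iters=1e4)

accuracy = np.mean(clf.predict(X_scaled) == y)  # fraction of correct predictions
print('training accuracy: {:.2f}'.format(accuracy))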
Contents of the **ex2data1.txt** data file
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
70.66150955499435,92.92713789364831,1
76.97878372747498,47.57596364975532,1
67.37202754570876,42.83843832029179,0
89.67677575072079,65.79936592745237,1
50.534788289883,48.85581152764205,0
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1