跟着Leo机器学习:
一个很有趣的个人博客,不信你来撩 fangzengye.com
sklearn框架

函数导图

1.2. Linear and Quadratic Discriminant Analysis
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Minimal LDA demo: two groups of three 2-D points each, labelled 1 and 2.
# The first group lies in the negative quadrant, the second in the positive
# one, so a single linear boundary separates them.
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
y = np.array([1, 1, 1, 2, 2, 2])

clf = LinearDiscriminantAnalysis()
clf.fit(X, y)

# Classify a new point that sits right next to the first group.
print(clf.predict([[-0.8, -1]]))
Normal and Shrinkage Linear Discriminant Analysis for classification
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_blobs
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Experiment settings for comparing plain vs. shrinkage LDA.
n_train = 20  # samples for training
n_test = 200  # samples for testing
n_averages = 50  # how often to repeat classification
n_features_max = 75  # maximum number of features
step = 4  # step size for the calculation


def generate_data(n_samples, n_features):
    """Generate random blob-ish data with noisy features.

    Only the first feature carries discriminative information: it is drawn
    by ``make_blobs`` from two clusters centred at -2 and +2. Every other
    feature is pure noise from ``np.random.randn``.

    Parameters
    ----------
    n_samples : int
        Number of samples (rows) to generate.
    n_features : int
        Total number of features (columns), including the single
        discriminative one. Must be at least 1.

    Returns
    -------
    X : ndarray of shape (n_samples, n_features)
        Input data.
    y : ndarray of shape (n_samples,)
        Integer cluster labels from ``make_blobs``, one per sample.

    Raises
    ------
    ValueError
        If ``n_features`` is less than 1. Previously such values silently
        produced a single-column ``X``.
    """
    if n_features < 1:
        raise ValueError(f"n_features must be >= 1, got {n_features}")

    X, y = make_blobs(n_samples=n_samples, n_features=1, centers=[[-2], [2]])

    # Append non-discriminative (pure noise) features.
    if n_features > 1:
        X = np.hstack([X, np.random.randn(n_samples, n_features - 1)])
    return X, y


# Mean test accuracy per feature count, for shrinkage LDA (clf1) and
# plain LDA (clf2).
acc_clf1, acc_clf2 = [], []
n_features_range = range(1, n_features_max + 1, step)
for n_features in n_features_range:
    score_clf1, score_clf2 = 0, 0
    # Repeat the train/test cycle n_averages times on fresh random data to
    # smooth out the variance of a single small training set.
    for _ in range(n_averages):
        X_train, y_train = generate_data(n_train, n_features)

        # Same least-squares solver for both; only shrinkage differs.
        # shrinkage='auto' lets scikit-learn pick the shrinkage amount
        # (Ledoit-Wolf), while shrinkage=None uses the plain empirical
        # covariance.
        clf1 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto').fit(X_train, y_train)
        clf2 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage=None).fit(X_train, y_train)

        # Evaluate on an independent, larger test set.
        X_test, y_test = generate_data(n_test, n_features)
        score_clf1 += clf1.score(X_test, y_test)
        score_clf2 += clf2.score(X_test, y_test)

    acc_clf1.append(score_clf1 / n_averages)
    acc_clf2.append(score_clf2 / n_averages)

# x-axis for the plot: how many features there are per training sample.
features_samples_ratio = np.array(n_features_range) / n_train
# Mean test accuracy of both classifiers against the feature/sample ratio.
plt.plot(features_samples_ratio, acc_clf1, linewidth=2,
         label="Linear Discriminant Analysis with shrinkage", color='navy')
plt.plot(features_samples_ratio, acc_clf2, linewidth=2,
         label="Linear Discriminant Analysis", color='gold')

plt.xlabel('n_features / n_samples')
plt.ylabel('Classification accuracy')

# 'upper right' is the named form of the numeric legend location code 1.
plt.legend(loc='upper right', prop={'size': 12})
# Adjacent string literals are joined at compile time, so the title stays one
# line of text; a raw line break inside a quoted string is a SyntaxError.
plt.suptitle('Linear Discriminant Analysis vs. '
             'shrinkage Linear Discriminant Analysis (1 discriminative feature)')
plt.show()
核心代码(每次重复实验的内层循环体):
# Core of the experiment: the body of the inner averaging loop.
X, y = generate_data(n_train, n_features)

# Fit shrinkage LDA and plain LDA on the same training data.
clf1 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto').fit(X, y)
clf2 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage=None).fit(X, y)

# Score both on a fresh test set and accumulate for averaging.
X, y = generate_data(n_test, n_features)
score_clf1 += clf1.score(X, y)
score_clf2 += clf2.score(X, y)
浙公网安备 33010602011771号