# numpy实现k近邻法 (k-nearest neighbours implemented with NumPy)

import numpy as np 
import matplotlib.pyplot as plt  
# Total number of synthetic samples.
w = 250
# Draw w samples with 4 integer-valued features from [-300, 300) and store
# them as floats so they can be standardised in place later on.
train = np.random.randint(low=-300, high=300, size=(w, 4)).astype(float)
# One label slot per sample, initialised to zero.
train_lable = np.zeros(w)
def kzhe(data=None, labels=None, fold=50):
    """Rotate the samples and their labels forward by one fold, in place.

    The last ``fold`` rows move to the front and every other row shifts
    ``fold`` positions towards the end. With the module's 250 samples and the
    default ``fold=50`` this cycles five blocks of 50 rows, so five
    consecutive calls restore the original order. (The name is presumably
    short for "k-fold".)

    Args:
        data: 2-D sample array to rotate along its first axis. Defaults to
            the module-level ``train``.
        labels: 1-D label array rotated in lockstep with ``data``. Defaults
            to the module-level ``train_lable``.
        fold: Number of rows to rotate by. Defaults to 50.

    Returns:
        None. Both arrays are modified in place.
    """
    if data is None:
        data = train
    if labels is None:
        labels = train_lable
    # np.roll returns a new array, so writing it back through [:] cannot
    # clobber rows that still have to be read.
    data[:] = np.roll(data, fold, axis=0)
    labels[:] = np.roll(labels, fold, axis=0)
# Standardise each of the four feature columns in place: subtract the column
# mean and divide by the column (population) standard deviation.
for column in range(4):
    values = train[:, column]
    train[:, column] = (values - values.mean()) / values.std()
# Label every sample by which side of the hyperplane
# 1*x0 + 2*x1 + 3*x2 + 4*x3 - 1 = 0 it falls on: +1 when the expression is
# strictly positive, -1 otherwise.
score = 1*train[:w, 0] + 2*train[:w, 1] + 3*train[:w, 2] + 4*train[:w, 3] - 1
train_lable[:w] = np.where(score > 0, 1, -1)
def knn(index, k, sig, data=None, labels=None, n_train=200):
    """Classify one sample with Gaussian-weighted k-nearest neighbours.

    The first ``n_train`` rows of ``data`` act as the reference set. The
    ``k`` rows closest to ``data[index]`` (squared Euclidean distance) each
    vote for their label with weight ``exp(-d**2 / (2 * sig**2))``, where
    ``d`` is the Euclidean distance to the query.

    Args:
        index: Row of ``data`` to classify.
        k: Number of nearest neighbours that vote.
        sig: Width (sigma) of the Gaussian weighting kernel.
        data: 2-D array-like of samples. Defaults to the module-level
            ``train``.
        labels: 1-D array-like of labels; entries equal to 1 count as the
            positive class, anything else as the negative class. Defaults
            to the module-level ``train_lable``.
        n_train: Number of leading rows used as the reference set.
            Defaults to 200.

    Returns:
        1 if the positive class has strictly more weighted votes, else -1
        (ties, including the case of no votes at all, yield -1).
    """
    if data is None:
        data = train
    if labels is None:
        labels = train_lable
    data = np.asarray(data, dtype=float)
    labels = np.asarray(labels)

    diffs = data[:n_train] - data[index]
    sq_dist = (diffs * diffs).sum(axis=1)

    # A stable sort keeps the lowest row index first among equal distances,
    # the same tie-breaking as repeatedly taking the first minimum.
    nearest = np.argsort(sq_dist, kind="stable")[:k]

    # sq_dist is already the squared distance, so it enters the Gaussian
    # kernel as-is; squaring it again would weight by exp(-d**4 / ...).
    weights = np.exp(-sq_dist[nearest] / (2 * sig ** 2))

    positive = labels[nearest] == 1
    l1 = weights[positive].sum()
    l2 = weights[~positive].sum()
    return 1 if l1 > l2 else -1

# Grid search over the neighbour count k (1..9) and the kernel width wid
# (1..19), scoring each pair by 5-fold cross-validated accuracy.
best_score = -1
best_k = 0
best_wid = 0

for k in range(1, 10):
    for wid in range(1, 20):
        hits = 0
        for _ in range(5):
            # Rotate the next fold into rows 200-249, which serve as the
            # held-out set for this round.
            kzhe()
            for row in range(200, 250):
                if knn(row, k, wid) == train_lable[row]:
                    hits += 1

        # 5 folds x 50 held-out samples = 250 predictions per (k, wid) pair.
        accuracy = hits * 1.0 / 250
        if best_score < accuracy:
            best_score, best_k, best_wid = accuracy, k, wid
        print("准确率", accuracy, "k值:", k, "宽度:", wid)
print(best_score, best_k, best_wid)
# posted @ 2021-08-15 19:58  祥瑞哈哈哈  阅读(70)  评论(0)    收藏  举报