# Classification with a random forest (用随机森林做分类)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.fft as fft
# Load the training data: 'ID' is dropped and the last column is the class label.
df = pd.read_csv('train.csv')
df = df.drop(['ID'], axis=1)
nmp = df.to_numpy()
# The last 20 rows are held out (for validation further down), so only the
# remaining rows take part in choosing frequency bins.
feature = nmp[:-20, :-1]
label = nmp[:-20, -1]
# Amplitude spectrum: |FFT| scaled by 2/N with N = 240.
# NOTE(review): assumes every row is a 240-sample signal — confirm against train.csv.
feature = torch.fft.fft(torch.Tensor(feature))
feature = torch.abs(feature) / 240 * 2
feature = feature.detach().numpy()

# Collect every frequency bin whose amplitude reaches 0.2 in at least one
# training row. Order is first appearance: row by row, ascending bin index
# within a row. A set makes the "already seen" test O(1) instead of scanning li.
li = []
seen_bins = set()
for row_mask in feature >= 0.2:
    for j in np.nonzero(row_mask)[0]:
        if j not in seen_bins:
            seen_bins.add(j)
            li.append(j)
print(li)
print(len(li))

from torch import nn
import torch

# Rebuild the features from train.csv, keeping only the frequency bins in li.
# The last 20 rows form the validation split (test_feature / test_label).
df = pd.read_csv('train.csv')
df = df.drop(['ID'], axis=1)

# Take labels straight from the DataFrame column. DataFrame.to_numpy() upcasts
# mixed int/float columns to float64, which would turn integer class labels
# into 0.0/1.0/..., and those floats would then be written to out.csv.
# y is kept 1-D: an (n, 1) column vector makes RandomForestClassifier.fit
# emit a DataConversionWarning.
labels_all = df.iloc[:, -1].to_numpy()
label = labels_all[:-20]
test_label = labels_all[-20:]

# One shared pipeline for both splits: amplitude spectrum with 2/N scaling
# (N = 240 — NOTE(review): assumed to be the samples per row, confirm), restricted
# to the selected bins and converted to NumPy so train and validation
# features have the same type.
spectrum = torch.fft.fft(torch.Tensor(df.iloc[:, :-1].to_numpy()))
spectrum = torch.abs(spectrum) / 240 * 2
spectrum = spectrum[:, li].detach().numpy()
feature = spectrum[:-20]
test_feature = spectrum[-20:]

from sklearn import svm
import matplotlib.pyplot as plt
from sklearn import tree

from sklearn.ensemble import RandomForestClassifier

# Random forest of 2000 trees, each limited to depth 8. random_state pins the
# bootstrap samples and per-split feature sub-sampling, so the printed
# accuracies and the written predictions are reproducible run to run.
clf = RandomForestClassifier(n_estimators=2000, max_depth=8, random_state=0)
# Train the classifier
# Result: accuracy 0.83, which is not very good. (准确率0.83效果不太好。)
import sklearn
from sklearn.metrics import accuracy_score

# Fit the forest on the training split, then print two accuracies in order:
# first on the training rows themselves (in-sample, so optimistic), then on
# the 20 held-out validation rows.
clf.fit(feature, label)
for eval_x, eval_y in ((feature, label), (test_feature, test_label)):
    print(accuracy_score(eval_y, clf.predict(eval_x)))
# Predict on the unlabeled test set and write a submission CSV with the
# columns ID, CLASS.
df = pd.read_csv('test.csv')
# Use the IDs from the file itself instead of assuming they continue at 210
# right after the training rows.
test_ids = df['ID'].to_numpy()
df = df.drop(['ID'], axis=1)
nmp = df.to_numpy()
# Same feature pipeline as for training: amplitude spectrum with 2/N scaling
# (N = 240), restricted to the frequency bins in li.
feature = torch.fft.fft(torch.Tensor(nmp))
feature = torch.abs(feature) / 240 * 2
feature = feature[:, li].detach().numpy()
out = pd.DataFrame({'ID': test_ids, 'CLASS': clf.predict(feature)})
out.to_csv('out.csv', index=False)
# Posted 2022-12-04 19:02 by 祥瑞哈哈哈 (views: 29, comments: 0)