# 朴素贝叶斯实例

# 高斯模型API
import numpy as np
from sklearn.naive_bayes import GaussianNB
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

# Load scikit-learn's digits dataset and hold out 20% of it for testing;
# the fixed random_state makes the split reproducible across runs.
digits = load_digits()
feature, target = digits.data, digits.target
x_train, x_test, y_train, y_test = train_test_split(
    feature, target, test_size=0.2, random_state=2020
)

# Fit a Gaussian naive Bayes classifier (fit() returns the estimator itself)
# and print its mean accuracy on the held-out split.
nb = GaussianNB().fit(x_train, y_train)
score = nb.score(x_test, y_test)
print(score)

# Inspect one test sample (index 5); reshape to a single-row 2-D array
# because the estimator's predict* methods take a 2-D feature matrix.
sample = x_test[5].reshape(1, -1)
# Predicted class for the sample
print(nb.predict(sample))
# Ground-truth class for the sample
print(y_test[5])
# Per-class LOG-probabilities for the sample (predict_log_proba, not raw probabilities)
print(nb.predict_log_proba(sample))
 

 

# 多项式朴素贝叶斯API(新闻分类案例)
import sklearn.datasets  as datasets
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB

# Fetch every 20 Newsgroups post (subset="all" covers both the train and test portions).
news = datasets.fetch_20newsgroups(subset="all")
# print(news)
feature = news.data  # list of raw article texts
target = news.target  # category label for each article
# Hold out 20% for evaluation; a fixed random_state keeps the split (and hence
# the printed accuracy) reproducible between runs.
x_train, x_test, y_train, y_test = train_test_split(
    feature, target, test_size=0.2, random_state=2020
)
# TF-IDF feature extraction on the data: learn vocabulary/IDF weights from the
# training texts only, then apply the same fitted transform to the test texts.
tf = TfidfVectorizer()
x_train = tf.fit_transform(x_train)
x_test = tf.transform(x_test)
# Hyper-parameters spelled out explicitly: alpha=1.0 is additive (Laplace)
# smoothing, fit_prior=True learns class priors from the training labels,
# class_prior=None supplies no user-defined priors.
mlt = MultinomialNB(alpha=1.0, fit_prior=True, class_prior=None)
# Train on the TRAINING split only, so the score below measures accuracy on
# data the model has not seen.
mlt.fit(x_train, y_train)
y_predict = mlt.predict(x_test)
print("预测结果:",y_predict)
print("真实结果:",y_test)
print("准确率:",mlt.score(x_test,y_test))

 

# 伯努利模型(BernoulliNB 假设特征为二值,例如某个词是否出现;它本身同样支持多分类。对词频/TF-IDF 这类计数型文本特征,通常建议使用多项式朴素贝叶斯)
import sklearn.datasets  as datasets
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import BernoulliNB
# Fetch every 20 Newsgroups post (subset="all" covers both the train and test portions).
news = datasets.fetch_20newsgroups(subset="all")
feature = news.data  # list of raw article texts
target = news.target  # category label for each article
# Hold out 25% for evaluation; a fixed random_state keeps the split (and hence
# the printed accuracy) reproducible between runs.
x_train, x_test, y_train, y_test = train_test_split(
    feature, target, test_size=0.25, random_state=2020
)
# TF-IDF feature extraction on the data: learn vocabulary/IDF weights from the
# training texts only, then apply the same fitted transform to the test texts.
tf = TfidfVectorizer()
x_train = tf.fit_transform(x_train)
x_test = tf.transform(x_test)
# NOTE: BernoulliNB binarizes its input (per the sklearn docs the `binarize`
# threshold defaults to 0.0), so each nonzero TF-IDF weight is treated as
# "term present" -- the TF-IDF magnitudes themselves are not used.
mlt = BernoulliNB(alpha=1)
# Train on the TRAINING split only, so the score below measures accuracy on
# data the model has not seen.
mlt.fit(x_train, y_train)
y_predict = mlt.predict(x_test)
print("预测结果:",y_predict)
print("真实结果:",y_test)
print("准确率:",mlt.score(x_test,y_test))

 

# posted @ 2021-07-07 14:39  邓居旺  阅读(155)  评论(0)    收藏  举报