openGauss Source Code Analysis (188)

openGauss Source Code Analysis: AI Technology (35)

8.7.5 Example: Using MADlib on openGauss

This section walks through house-price classification with a support vector machine (SVM) to show how MADlib is used in practice.

(1) Prepare the dataset:

DROP TABLE IF EXISTS houses;
CREATE TABLE houses (id INT, tax INT, bedroom INT, bath FLOAT, price INT, size INT, lot INT);
INSERT INTO houses VALUES
(1, 590, 2, 1, 50000, 770, 22100),
(2, 1050, 3, 2, 85000, 1410, 12000),
(3, 20, 3, 1, 22500, 1060, 3500),
-- ids 4 to 11 are omitted from this listing; the full dataset has 15 rows
(12, 1620, 3, 2, 118600, 1250, 20000),
(13, 3100, 3, 2, 140000, 1760, 38000),
(14, 2070, 2, 3, 148000, 1550, 14000),
(15, 650, 3, 1.5, 65000, 1450, 12000);

(2) Train the model

① Before training, set the schema search path and the compatibility option:

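-- Put the madlib schema on the search path so MADlib objects resolve without a prefix
SET search_path="$user",public,madlib;
-- openGauss option: unqualified names inside stored procedures fall back to search_path
SET behavior_compat_options = 'bind_procedure_searchpath';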

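② Train with the default parameters, using 'price < 100000' as the classification condition. The label is this boolean expression, and the feature vector ARRAY[1, tax, bath, size] includes a constant 1 that acts as the intercept term: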

DROP TABLE IF EXISTS houses_svm, houses_svm_summary;
-- Arguments: source table, output model table, dependent variable, independent variables
SELECT madlib.svm_classification('public.houses', 'public.houses_svm', 'price < 100000', 'ARRAY[1, tax, bath, size]');
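To make the call's arguments concrete: the third argument is evaluated per row as a boolean label, and the fourth builds the feature vector, with the constant 1 serving as an intercept term. The Python sketch below mimics that mapping on the rows listed in step (1). The helper name `to_training_example` is hypothetical; it only illustrates what the two SQL expressions compute and says nothing about how MADlib reads the table internally.

```python
# Rows listed in step (1): (id, tax, bedroom, bath, price, size, lot)
HOUSES = [
    (1, 590, 2, 1.0, 50000, 770, 22100),
    (2, 1050, 3, 2.0, 85000, 1410, 12000),
    (3, 20, 3, 1.0, 22500, 1060, 3500),
    (12, 1620, 3, 2.0, 118600, 1250, 20000),
    (13, 3100, 3, 2.0, 140000, 1760, 38000),
    (14, 2070, 2, 3.0, 148000, 1550, 14000),
    (15, 650, 3, 1.5, 65000, 1450, 12000),
]


def to_training_example(row):
    """Hypothetical helper mirroring 'ARRAY[1, tax, bath, size]' and 'price < 100000'."""
    _id, tax, _bedroom, bath, price, size, _lot = row
    features = [1.0, float(tax), float(bath), float(size)]  # leading 1.0 = intercept
    label = price < 100000  # boolean dependent variable
    return features, label


if __name__ == "__main__":
    for row in HOUSES:
        x, y = to_training_example(row)
        print(row[0], x, y)
```

Only columns named in the SQL expressions are used; bedroom and lot are ignored by this model.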

(3) Inspect the model:

\x on
SELECT * FROM houses_svm;
\x off

The result is as follows:

-[ RECORD 1 ]------+-----------------------------------------------------------------
coef               | {.113989576847,-.00226133300602,-.0676303607996,.00179440841072}
loss               | .614496714256667
norm_of_gradient   | 108.171180769224
num_iterations     | 100
num_rows_processed | 15
num_rows_skipped   | 0
dep_var_mapping    | {f,t}

(4) Run prediction:

DROP TABLE IF EXISTS houses_pred;
-- Arguments: model table, table to score, id column, output table
SELECT madlib.svm_predict('public.houses_svm', 'public.houses', 'id', 'public.houses_pred');

(5) Check the prediction results:

SELECT *, price < 100000 AS actual FROM houses JOIN houses_pred USING (id) ORDER BY id;

The result is as follows (abbreviated; only 4 of the 15 rows are shown):

 id | tax  | bedroom | bath | price  | size |  lot  | prediction | decision_function | actual
----+------+---------+------+--------+------+-------+------------+-------------------+--------
  1 |  590 |       2 |    1 |  50000 |  770 | 22100 | t          |      .09386721875 | t
  2 | 1050 |       3 |    2 |  85000 | 1410 | 12000 | t          |     .134445058042 | t
 14 | 2070 |       2 |    3 | 148000 | 1550 | 14000 | f          |  -1.9885277913972 | f
 15 |  650 |       3 |  1.5 |  65000 | 1450 | 12000 | t          |   1.1445697772786 | t
(15 rows)
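The decision_function column can be reproduced by hand from the coef vector in step (3): for the default linear kernel it matches the dot product of coef with [1, tax, bath, size], and a positive score corresponds to the second entry of dep_var_mapping ({f,t}), i.e. t. The Python sketch below checks this against the four rows shown above. The sign rule is inferred from this output rather than taken from MADlib's source.

```python
# coef from the model table in step (3), ordered like ARRAY[1, tax, bath, size]
COEF = [0.113989576847, -0.00226133300602, -0.0676303607996, 0.00179440841072]
DEP_VAR_MAPPING = [False, True]  # {f,t} from the model table

# (id, tax, bath, size) for the rows shown in the prediction output
ROWS = [(1, 590, 1.0, 770), (2, 1050, 2.0, 1410), (14, 2070, 3.0, 1550), (15, 650, 1.5, 1450)]


def decision_function(coef, tax, bath, size):
    """Linear SVM score: dot product of coef with [1, tax, bath, size]."""
    x = [1.0, tax, bath, size]
    return sum(c * v for c, v in zip(coef, x))


def predict(coef, tax, bath, size):
    """Positive score maps to the second label (t); otherwise the first (f)."""
    score = decision_function(coef, tax, bath, size)
    return DEP_VAR_MAPPING[1] if score > 0 else DEP_VAR_MAPPING[0]


if __name__ == "__main__":
    for row_id, tax, bath, size in ROWS:
        score = decision_function(COEF, tax, bath, size)
        print(row_id, round(score, 10), predict(COEF, tax, bath, size))
```

For example, row 1 gives 0.113989576847 - 0.00226133300602 * 590 - 0.0676303607996 * 1 + 0.00179440841072 * 770 = 0.09386721875, the value in the table.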

Count the misclassified rows, i.e. rows whose predicted label differs from the actual value of price < 100000:

SELECT COUNT(*) FROM houses_pred JOIN houses USING (id) WHERE houses_pred.prediction != (houses.price < 100000);

The result is as follows (3 of the 15 rows are misclassified, a misclassification rate of 20%):

 count
-------
     3
(1 row)
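The query above returns a count; dividing it by the number of rows processed (15, per num_rows_processed in step (3)) gives the misclassification rate. A small Python sketch of the same arithmetic follows; the helper names `misclassification_count` and `misclassification_rate` are hypothetical.

```python
def misclassification_count(predictions, actuals):
    """Number of positions where prediction != actual (mirrors the SQL WHERE clause)."""
    if len(predictions) != len(actuals):
        raise ValueError("predictions and actuals must have the same length")
    return sum(1 for p, a in zip(predictions, actuals) if p != a)


def misclassification_rate(count, total):
    """Fraction of misclassified rows."""
    if total <= 0:
        raise ValueError("total must be positive")
    return count / total


if __name__ == "__main__":
    # The four rows shown in step (5): prediction vs. actual
    shown_pred = [True, True, False, True]
    shown_actual = [True, True, False, True]
    print(misclassification_count(shown_pred, shown_actual))  # 0 on these four rows
    # Linear kernel: 3 misclassified rows out of 15 processed
    print(misclassification_rate(3, 15))  # 0.2
```

All misclassifications of the linear model therefore lie in rows not shown in the abbreviated output.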

(6) Train with a different SVM kernel, using the Gaussian kernel as an example:

DROP TABLE IF EXISTS houses_svm_gaussian, houses_svm_gaussian_summary, houses_svm_gaussian_random;
-- Extra arguments: kernel 'gaussian', kernel params 'n_components=10', no grouping column (''), optimizer params
SELECT madlib.svm_classification('public.houses', 'public.houses_svm_gaussian', 'price < 100000', 'ARRAY[1, tax, bath, size]', 'gaussian', 'n_components=10', '', 'init_stepsize=1, max_iter=200');

Then run prediction with the Gaussian model and count the misclassified rows:

DROP TABLE IF EXISTS houses_pred_gaussian;
SELECT madlib.svm_predict('public.houses_svm_gaussian', 'public.houses', 'id', 'public.houses_pred_gaussian');
SELECT COUNT(*) FROM houses_pred_gaussian JOIN houses USING (id) WHERE houses_pred_gaussian.prediction != (houses.price < 100000);

The result is as follows (none of the 15 rows is misclassified):

 count
-------
     0
(1 row)
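MADlib's documentation describes its Gaussian kernel as an approximation built from random features, with n_components setting how many are drawn. Assuming that, the Python sketch below shows the generic random Fourier feature construction for k(x, y) = exp(-gamma * ||x - y||^2): sample W ~ N(0, 2 * gamma * I) and b ~ U[0, 2 * pi), then map x to sqrt(2 / D) * cos(W x + b), so that a linear SVM on the mapped features approximates a Gaussian-kernel SVM. This is not MADlib's code; gamma, the seed and the helper names are illustrative assumptions.

```python
import math
import random


def make_random_fourier_map(dim, n_components, gamma, seed=0):
    """Sample W ~ N(0, 2*gamma*I) and b ~ U[0, 2*pi) for k(x, y) = exp(-gamma * ||x - y||^2)."""
    rng = random.Random(seed)
    std = math.sqrt(2.0 * gamma)
    weights = [[rng.gauss(0.0, std) for _ in range(dim)] for _ in range(n_components)]
    offsets = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(n_components)]
    scale = math.sqrt(2.0 / n_components)

    def transform(x):
        # z(x)_i = sqrt(2 / D) * cos(w_i . x + b_i); z(x) . z(y) approximates k(x, y)
        return [scale * math.cos(sum(w_j * x_j for w_j, x_j in zip(w, x)) + b)
                for w, b in zip(weights, offsets)]

    return transform


def gaussian_kernel(x, y, gamma):
    """Exact Gaussian (RBF) kernel, for comparison with the approximation."""
    return math.exp(-gamma * sum((a - b) ** 2 for a, b in zip(x, y)))


if __name__ == "__main__":
    x, y = [0.2, -0.1, 0.5], [0.0, 0.3, 0.4]
    z = make_random_fourier_map(dim=3, n_components=5000, gamma=1.0, seed=42)
    approx = sum(a * b for a, b in zip(z(x), z(y)))
    print(round(approx, 3), round(gaussian_kernel(x, y, 1.0), 3))
```

The sampled W and b must be reused when scoring new rows, which is presumably why training also creates the houses_svm_gaussian_random table. With only n_components=10, as in the SQL above, the approximation is coarse; larger values trade speed for accuracy.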

(7) Other parameters

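Besides choosing a different kernel, you can also set optimizer parameters such as the initial step size (init_stepsize), the maximum number of iterations (max_iter), and the class weights (class_weight).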
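As a rough illustration of what these three parameters control, the Python sketch below trains a linear SVM by full-batch subgradient descent on a class-weighted, L2-regularized hinge loss. It is not MADlib's solver: the decay rule init_stepsize / sqrt(t), the regularization constant `lam`, and the "balanced" weighting formula n / (n_classes * n_c) are assumptions made for the example, and labels are encoded as +1 (e.g. price < 100000) and -1.

```python
import math


def balanced_class_weights(labels):
    """Assumed 'balanced' scheme: weight_c = n / (n_classes * n_c)."""
    n = len(labels)
    classes = sorted(set(labels))
    return {c: n / (len(classes) * labels.count(c)) for c in classes}


def train_linear_svm(xs, ys, init_stepsize=1.0, max_iter=100, lam=0.01, class_weight=None):
    """Full-batch subgradient descent on weighted, L2-regularized hinge loss.

    xs: feature vectors with a leading 1.0 for the intercept (as in ARRAY[1, ...]);
    ys: labels in {-1, +1}; class_weight: dict label -> weight, or None for uniform.
    The intercept is regularized too, for simplicity.
    """
    dim = len(xs[0])
    n = len(xs)
    w = [0.0] * dim
    weights = class_weight or {-1: 1.0, 1: 1.0}
    for t in range(1, max_iter + 1):
        step = init_stepsize / math.sqrt(t)  # assumed decay schedule
        grad = [lam * wj for wj in w]  # gradient of (lam / 2) * ||w||^2
        for x, y in zip(xs, ys):
            margin = y * sum(wj * xj for wj, xj in zip(w, x))
            if margin < 1:  # hinge loss active: subgradient is -y * x
                for j in range(dim):
                    grad[j] -= weights[y] * y * x[j] / n
        w = [wj - step * gj for wj, gj in zip(w, grad)]
    return w


def predict(w, x):
    """Sign of the linear score, as in the decision_function column."""
    return 1 if sum(wj * xj for wj, xj in zip(w, x)) > 0 else -1


if __name__ == "__main__":
    # Tiny separable toy set (intercept + one feature), not the houses data
    xs = [[1.0, -2.0], [1.0, -1.0], [1.0, 1.0], [1.0, 2.0]]
    ys = [-1, -1, 1, 1]
    w = train_linear_svm(xs, ys, init_stepsize=1.0, max_iter=200,
                         class_weight=balanced_class_weights(ys))
    print(w, [predict(w, x) for x in xs])
```

A toy set is used because the raw houses features (tax and size in the thousands) would need a much smaller step size or feature scaling; raising class_weight for one label makes its margin violations count more in each update.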

posted @ 2024-05-06 10:45 openGauss-bot