openGauss源码解析(188)
openGauss源码解析:AI技术(35)
8.7.5 MADlib在openGauss上的使用示例
这里以通过支持向量机算法进行房价分类为例,演示具体的使用方法。
(1) 数据集准备,代码如下:
DROP TABLE IF EXISTS houses;
CREATE TABLE houses (id INT, tax INT, bedroom INT, bath FLOAT, price INT, size INT, lot INT);
INSERT INTO houses VALUES
(1 , 590 , 2 , 1 , 50000 , 770 , 22100),
(2 , 1050 , 3 , 2 , 85000 , 1410 , 12000),
(3 , 20 , 3 , 1 , 22500 , 1060 , 3500),
…
(12 , 1620 , 3 , 2 , 118600 , 1250 , 20000),
(13 , 3100 , 3 , 2 , 140000 , 1760 , 38000),
(14 , 2070 , 2 , 3 , 148000 , 1550 , 14000),
(15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000);
(2) 模型训练
① 训练前配置相应schema和兼容性参数,代码如下:
SET search_path="$user",public,madlib;
SET behavior_compat_options = 'bind_procedure_searchpath';
② 使用默认的参数进行训练,分类的条件为‘price < 100000’,SQL语句如下:
DROP TABLE IF EXISTS houses_svm, houses_svm_summary;
SELECT madlib.svm_classification('public.houses','public.houses_svm','price < 100000','ARRAY[1, tax, bath, size]');
(3) 查看模型,代码如下:
\x on
SELECT * FROM houses_svm;
\x off
结果如下:
-[ RECORD 1 ]------+-----------------------------------------------------------------
coef | {.113989576847,-.00226133300602,-.0676303607996,.00179440841072}
loss | .614496714256667
norm_of_gradient | 108.171180769224
num_iterations | 100
num_rows_processed | 15
num_rows_skipped | 0
dep_var_mapping | {f,t}
(4) 进行预测,代码如下:
DROP TABLE IF EXISTS houses_pred;
SELECT madlib.svm_predict('public.houses_svm','public.houses','id','public.houses_pred');
(5) 查看预测结果,代码如下:
SELECT *, price < 100000 AS actual FROM houses JOIN houses_pred USING (id) ORDER BY id;
结果如下:
id | tax | bedroom | bath | price | size | lot | prediction | decision_function | actual
----+------+---------+------+--------+------+-------+------------+-------------------+--------
1 | 590 | 2 | 1 | 50000 | 770 | 22100 | t | .09386721875 | t
2 | 1050 | 3 | 2 | 85000 | 1410 | 12000 | t | .134445058042 | t
…
14 | 2070 | 2 | 3 | 148000 | 1550 | 14000 | f | -1.9885277913972 | f
15 | 650 | 3 | 1.5 | 65000 | 1450 | 12000 | t | 1.1445697772786 | t
(15 rows)
查看误分率,代码如下:
SELECT COUNT(*) FROM houses_pred JOIN houses USING (id) WHERE houses_pred.prediction != (houses.price < 100000);
结果如下:
count
-------
3
(1 row)
(6) 使用svm其他核进行训练,代码如下:
DROP TABLE IF EXISTS houses_svm_gaussian, houses_svm_gaussian_summary, houses_svm_gaussian_random;
SELECT madlib.svm_classification( 'public.houses','public.houses_svm_gaussian','price < 100000','ARRAY[1, tax, bath, size]','gaussian','n_components=10', '', 'init_stepsize=1, max_iter=200' );
进行预测,并查看训练结果。
DROP TABLE IF EXISTS houses_pred_gaussian;
SELECT madlib.svm_predict('public.houses_svm_gaussian','public.houses','id', 'public.houses_pred_gaussian');
SELECT COUNT(*) FROM houses_pred_gaussian JOIN houses USING (id) WHERE houses_pred_gaussian.prediction != (houses.price < 100000);
结果如下:
count
-------+
0
(1 row)
(7) 其他参数
除了指定不同的核方法外,还可以指定迭代次数、初始参数,比如init_stepsize,max_iter,class_weight等。

浙公网安备 33010602011771号