#!/usr/bin/env python3# -*- coding: utf-8 -*-"""
Created on Thu Jun 7 18:15:30 2018
@author: luogan
"""from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql import SparkSession
spark= SparkSession\
.builder \
.appName("dataFrame") \
.getOrCreate()
# Load and parse the data file, converting it to a DataFrame.
data = spark.read.format("libsvm").load("/home/luogan/lg/softinstall/spark-2.2.0-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt")
# Index labels, adding metadata to the label column.# Fit on whole dataset to include all labels in index.
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
# Automatically identify categorical features, and index them.# Set maxCategories so features with > 4 distinct values are treated as continuous.
featureIndexer =\
VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a GBT model.
gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10)
# Chain indexers and GBT in a Pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt])
# Train model. This also runs the indexers.
model = pipeline.fit(trainingData)
# Make predictions.
predictions = model.transform(testData)
# Select example rows to display.
predictions.select("prediction", "indexedLabel", "features").show(5)
# Select (prediction, true label) and compute test error
evaluator = MulticlassClassificationEvaluator(
labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Test Error = %g" % (1.0 - accuracy))
gbtModel = model.stages[2]
print(gbtModel) # summary only
+----------+------------+--------------------+|prediction|indexedLabel| features|
+----------+------------+--------------------+
| 1.0| 1.0|(692,[95,96,97,12...|
| 1.0| 1.0|(692,[121,122,123...|
| 1.0| 1.0|(692,[122,123,124...|
| 1.0| 1.0|(692,[124,125,126...|
| 1.0| 1.0|(692,[124,125,126...|
+----------+------------+--------------------+
only showing top 5 rows
Test Error = 0.0571429
GBTClassificationModel (uid=GBTClassifier_483a8eddb2c54d041fae) with 10 trees