Implementing an autoKMeans algorithm with Spark RDD: distributed unsupervised clustering on Spark
package scala.learningRDD
import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ArrayBuffer
object KmsModelDef {
// Despite its name, eudist returns a Jaccard-style similarity on binary features, not a
// Euclidean distance: matching non-zero positions divided by the positions where either
// vector is non-zero (1.0 = identical non-zero pattern, 0.0 = no overlap).
def eudist(v1: Array[Double], v2: Array[Double]): Double = {
require(v1.length == v2.length, "Vectors must be of the same length.")
// Euclidean alternative: math.sqrt(v1.zip(v2).map { case (a, b) => math.pow(a - b, 2) }.sum)
val union = v1.zip(v2).count { case (a, b) => a > 0 || b > 0 }
// If both vectors are all-zero, define the similarity as 0.0 to avoid division by zero.
if (union == 0) 0.0
else v1.zip(v2).count { case (a, b) => a == b && a > 0 && b > 0 }.toDouble / union
}
}
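// A quick worked example of the measure above (a sketch added for illustration; the object
// name EudistCheck is not part of the original code). For tails (1,1,1) and (1,0,0) there is
// one matching non-zero position out of three positions where either vector is non-zero,
// so the expected value is 1/3.
object EudistCheck {
def main(args: Array[String]): Unit = {
val a = Array(1.0, 1.0, 1.0)
val b = Array(1.0, 0.0, 0.0)
println(KmsModelDef.eudist(a, b)) // prints 0.3333...
}
}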
class ManualKMeansModel(spark: SparkSession, similarity: Double, iter: Int, acc: Double) {
// val spark: SparkSession = SparkSession.builder()
// .master("local[*]").appName("autoKMeans").getOrCreate()
import spark.implicits._
val sc: SparkContext = spark.sparkContext
var data: RDD[Vector] = _ // input data; element 0 of each vector is treated as the record ID
var centers: RDD[(Double, Vector)] = _ // (centerId, center vector)
var center_data: RDD[(Double, Double, Double)] = _ // (centerId, record ID, similarity to that center)
val smllt: Double = similarity // similarity threshold used when seeding centers
val iters: Long = iter // maximum number of fit iterations
val accs: Double = acc // convergence tolerance on the change of the total similarity
// var data_size:Long = _
// Optimized k_meas_jj method: greedily seed centers until every point is sufficiently
// similar to at least one center, or a safety cap is reached.
def k_meas_jj_optimized(): Unit = {
// Pick the first center: the record with the largest feature sum (max instead of reduce).
val maxtuple0: (Double, Double) = this.data.map(v => (v.toArray.head, v.toArray.tail.sum))
.max()(Ordering.by(_._2))
val cts_0: Vector = data.filter(v => v(0) == maxtuple0._1).first()
var iteval = 0.0
var n = 1
val centersList: ArrayBuffer[(Double, Vector)] = new ArrayBuffer[(Double, Vector)]()
centersList += ((0.0, cts_0))
this.centers = sc.parallelize(centersList.toList) // make sure centers is set even if we stop early
while (iteval < this.smllt && n < 1000) { // safety cap on the number of centers
val centersBroadcast: Broadcast[Map[Double, Vector]] = sc.broadcast(centersList.toMap)
// For every point, compute its coverage: the best (largest) similarity to any existing center.
// mapPartitions reads the broadcast value once per partition instead of once per record.
val worstCovered: (Vector, Double) = this.data.mapPartitions { partition =>
val localCenters: Map[Double, Vector] = centersBroadcast.value
partition.map { vector =>
val coverage = localCenters.values.map { center =>
KmsModelDef.eudist(center.toArray.tail, vector.toArray.tail)
}.max
(vector, coverage)
}
}.min()(Ordering.by(_._2)) // the "farthest" point is the one with the lowest coverage
iteval = worstCovered._2
// If even the worst-covered point is similar enough to an existing center, stop seeding.
if (iteval > this.smllt) return
val newCenter: (Double, Vector) = (n.toDouble, worstCovered._1)
centersList += newCenter
this.centers = sc.parallelize(centersList.toList) // snapshot the buffer before parallelizing
// Release the broadcast once this round is done to free executor memory.
centersBroadcast.unpersist()
n += 1
}
}
def assign_clusters_optimized(): Unit = {
// 1. Collect the centers (this.centers) to the driver and turn them into a Map for fast lookup.
// The number of k-means centers is typically tens to a few thousand, so this is safe to broadcast.
val centersLocalMap: Map[Double, Vector] = this.centers.collect().toMap
// 2. Broadcast the centers to every executor.
val centersBroadcast: Broadcast[Map[Double, Vector]] = sc.broadcast(centersLocalMap)
// 3. Map over the data RDD (this.data) directly; no cartesian product is needed.
val rsl_1: RDD[(Double, Double, Double)] = this.data.flatMap { dataPoint =>
// On each executor, read the broadcast center map via .value.
val localCenters = centersBroadcast.value
// Compute the similarity of the current point to every center.
localCenters.map { case (centerId, centerVector) =>
val similarity = KmsModelDef.eudist(centerVector.toArray.tail, dataPoint.toArray.tail)
(centerId, dataPoint.toArray.head, similarity) // the head of dataPoint is treated as the record ID
}
}
// 4. For every record, keep the center with the highest similarity, i.e. its nearest center.
val maxByGroup: RDD[(Double, Double, Double)] = rsl_1
.map(tuple => (tuple._2, tuple)) // key by record ID
.reduceByKey { (t1, t2) =>
if (t1._3 > t2._3) t1 else t2
}.map(_._2)
this.center_data = maxByGroup
// 5. Release the broadcast to free executor memory.
centersBroadcast.unpersist()
}
// Optimized update_centers method: recompute the centers and return the change in the total similarity.
def update_centers_optimized(): Double = {
val old_centers: Double = this.center_data.map(v => v._3).sum()
// Use a join instead of cartesian: key both the assignments and the data by record ID.
val centerDataKV: RDD[(Double, (Double, Double))] = this.center_data.map(tuple => (tuple._2, (tuple._1, tuple._3)))
val dataKV: RDD[(Double, Vector)] = this.data.map(v => (v.toArray.head, v))
val joined: RDD[(Double, ((Double, Double), Vector))] = centerDataKV.join(dataKV)
// Aggregate the members of each cluster with reduceByKey.
val newCentersInfo: RDD[(Double, (Vector, Int))] = joined.map { case (_, ((centerId, _), vector)) =>
(centerId, (vector, 1))
}.reduceByKey { case ((v1, c1), (v2, c2)) =>
// Placeholder aggregation: keep one member vector as the cluster representative and count the members.
// A mean-based update does not fit the binary Jaccard-style similarity used here; see the sketch below.
(v1, c1 + c2)
}
// Build the new center RDD from the aggregated representatives.
val updatedCenters: RDD[(Double, Vector)] = newCentersInfo.map { case (centerId, (vector, _)) =>
(centerId, vector)
}
this.centers = updatedCenters
this.assign_clusters_optimized() // refresh this.center_data with the new centers
val new_centers: Double = this.center_data.map(v => v._3).sum()
Math.abs(new_centers - old_centers)
}
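// A minimal sketch of what a mean-based center update could look like, assuming purely numeric
// features. This helper (mean_centers_sketch) is not part of the original code and is never
// called by fit(); with the binary Jaccard-style similarity above, a medoid or a per-dimension
// majority vote is likely the better choice, since fractional means rarely match exactly.
def mean_centers_sketch(assignments: RDD[(Double, Vector)]): RDD[(Double, Vector)] = {
assignments
.map { case (centerId, vector) => (centerId, (vector.toArray, 1L)) }
.reduceByKey { case ((a1, c1), (a2, c2)) =>
(a1.zip(a2).map { case (x, y) => x + y }, c1 + c2) // element-wise sum and member count
}
.map { case (centerId, (sums, count)) =>
(centerId, Vectors.dense(sums.map(_ / count))) // element-wise mean
}
}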
def fit(df:RDD[Vector]): Unit ={
this.data = df
// this.data_size = df.count()
// Seed the initial centers
k_meas_jj_optimized()
var n:Long = 0
var accing = Double.MaxValue
while (accing > this.accs && n < this.iters) {
assign_clusters_optimized()
accing = update_centers_optimized()
println(s"iteration $n: delta = $accing")
n = n + 1
}
// Show the results: centers and center_data
this.centers.toDF().show()
this.center_data.toDF().show()
}
}
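// A local, Spark-free illustration of the seeding rule used in k_meas_jj_optimized above
// (a sketch for illustration only; the object name SeedRuleSketch and its example values are
// not from the original code): a point's coverage is its best similarity to any existing
// center, and the worst-covered point becomes the next center whenever its coverage is below
// the threshold.
object SeedRuleSketch {
def main(args: Array[String]): Unit = {
val threshold = 0.65
val center = Array(0.0, 0.0, 1.0, 1.0)
val points = Seq(Array(0.0, 0.0, 1.0, 1.0), Array(1.0, 0.0, 0.0, 1.0), Array(1.0, 1.0, 1.0, 1.0))
val (worst, coverage) = points.map(p => (p, KmsModelDef.eudist(center, p))).minBy(_._2)
// Here the worst-covered point is (1,0,0,1) with coverage 1/3, so it would become the next center.
if (coverage < threshold) println(s"new center: ${worst.mkString(",")} (coverage $coverage)")
}
}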
object demo{
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder()
.master("local[*]")
.appName("autoKMeans")
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.kryoserializer.buffer.max", "256m")
.config("spark.sql.adaptive.enabled", "true")
.config("spark.sql.adaptive.coalescePartitions.enabled", "true")
.config("spark.default.parallelism", "10") // 根据集群调整
.config("spark.sql.shuffle.partitions", "10")
.getOrCreate()
val sc: SparkContext = spark.sparkContext
val dataseq: Seq[Vector] = Seq(
Vectors.dense(1, 0, 0, 1,1),
Vectors.dense(2, 0, 0, 1,1),
Vectors.dense(3, 1, 0, 0,1),
Vectors.dense(4, 1, 0, 0,1),
Vectors.dense(5, 1, 1, 1,1),
Vectors.dense(6, 1, 0, 1,1),
Vectors.dense(7, 0, 0, 1,1),
Vectors.dense(8, 0, 0, 1,1),
Vectors.dense(9, 1, 0, 0,1),
Vectors.dense(10, 1, 0, 0,1),
Vectors.dense(11, 1, 1, 1,1),
Vectors.dense(12, 1, 1, 0,1),
Vectors.dense(13, 1, 0, 1,0),
Vectors.dense(14, 1, 0, 1,0)
)
val data: RDD[Vector] = sc.parallelize(dataseq)
val model:ManualKMeansModel = new ManualKMeansModel(spark,0.65,300,0.00001)
model.fit(data)
// val r = KmsModelDef.eudist(Vectors.dense(6, 1, 1, 1).toArray.tail, Vectors.dense(4, 1, 0, 0).toArray.tail)
// println(r) // ≈ 0.3333 (one matching non-zero position out of three)
spark.stop()
}
}