import scala.util.Random

import breeze.linalg.DenseVector
import breeze.stats.distributions.Multinomial
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
import org.apache.spark.sql.functions._

// Binary search: return the first item whose cumulative probability is >= target
def fetchBinarySearch(trainItems: Array[(String, Double)], target: Double): String = {
  // val trainItems = Array(("1", 0), ("2", 1), ("3", 3), ("4", 4), ("5", 6))
  // val target = 6.0000000000018032
  if (trainItems.isEmpty) {
    ""
  } else {
    var left = 0
    var right = trainItems.length - 1
    while (left < right) {
      val mid = left + (right - left) / 2  // Int division; avoids overflow of left + right
      if (trainItems(mid)._2 < target) {
        left = mid + 1
      } else {
        right = mid
      }
    }
    trainItems(left)._1
  }
}
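// Usage sketch (hypothetical cumulative values): with
//   trainItems = Array(("a", 0.1), ("b", 0.4), ("c", 1.0))
// a target of 0.25 skips "a" (0.1 < 0.25) and stops at "b" (0.4 >= 0.25), so the
// function returns "b"; each item is drawn with probability equal to the width of
// its cumulative-probability interval (here 0.1, 0.3, 0.6 for "a", "b", "c").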
// Build each user's full sample item list: positives parsed from app_info plus negNum sampled negatives per positive
def fetchFullSampleItemsUdf(trainItems: Array[(String, Double)], trainItemsSize: Int, negNum: Int): UserDefinedFunction = udf(
  (app_info: String) => {
    // Each comma-separated entry of app_info is "itemId:..."; keep only the itemId as a positive
    val sampleItems = app_info.split(",").map(t => t.split(":")(0)).toBuffer
    val sampleItemsSet = scala.collection.mutable.Set[String]() ++ sampleItems.toSet
    val posNum = sampleItems.size
    var tmpNegNum = posNum * negNum
    // val trainItems = Array(("1", 0.1), ("2", 0.2), ("3", 0.3), ("4", 0.4))
    // Note: _._2 here is the cumulative fp, so the commented-out Multinomial
    // variant below would need the raw fw weights instead
    val probabilities = DenseVector(trainItems.map(_._2))
    val random = new Random  // create once; allocating a new Random per draw is wasteful
    while (tmpNegNum > 0 && sampleItemsSet.size < trainItemsSize) {  // second clause guards against exhausting the candidate pool
      // // Uniform (unweighted) negative sampling
      // val randomIndex = random.nextInt(trainItemsSize)
      // val negItem = trainItems(randomIndex)._1
      // Weighted negative sampling (binary search over the cumulative distribution)
      val randomTarget = random.nextDouble()
      val negItem = fetchBinarySearch(trainItems, randomTarget)
      // // Weighted negative sampling (via the library sampler)
      // val randomIndex = new Multinomial(probabilities).sample(1).head
      // val negItem = trainItems(randomIndex)._1
      if (!sampleItemsSet.contains(negItem)) {
        sampleItems.append(negItem)
        sampleItemsSet.add(negItem)  // record the draw so the same negative is not appended twice
        tmpNegNum -= 1
      }
    }
    // Positives keep their original positions, so index < posNum marks label 1
    sampleItems.zipWithIndex.map {
      case (item, i) =>
        val label = if (i < posNum) 1 else 0
        (item, label)
    }
  }
)
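// Shape sketch (hypothetical input): for app_info = "a:2,b:5" and negNum = 2 the UDF
// yields the 2 positives first, then 2 * 2 distinct sampled negatives, e.g.
// Buffer(("a",1), ("b",1), ("x",0), ("y",0), ("z",0), ("w",0)).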
// Assemble the sample data: expand positives/negatives and join user/item features
def fetchSampleData(spark: SparkSession, day: String, part: String, negNum: Int): DataFrame = {
  // val part = "0"
  val targetData = fetchTargetData(spark, day, part)
  val userMap = {
    targetData.select("user_id").dropDuplicates("user_id").rdd
      .map { row =>
        val user_id = row.getAs[String]("user_id")
        (user_id, "1")
      }.collect().toMap
  }
  // val trainItems = fetchItemSampleData(spark, day).dropDuplicates("appid").rdd.map{
  //   row => row.getAs[String]("appid")
  // }.collect()
  val win = Window.partitionBy("day")
  val win2 = Window.partitionBy("day").orderBy("pv")
  val win3 = Window.partitionBy("day").orderBy("rank")
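  // The block below builds the sampling distribution over items:
  // fw = count^0.75 / sum(count^0.75) (the word2vec-style smoothed item frequency),
  // and fp = running sum of fw in rank order, i.e. the CDF that fetchBinarySearch
  // walks with a uniform target in [0, 1).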
  val trainItems = {
    fetchItemSampleData(spark, day).groupBy("day", "appid").agg(expr("power(count(user_id), 0.75) as pv"))
      .withColumn("pv_sum", sum("pv").over(win))
      .withColumn("fw", col("pv") / col("pv_sum"))
      .withColumn("rank", row_number().over(win2))
      // The running sum must be ordered by the unique rank: ordering by pv directly
      // gives tied pv values identical cumulative sums (RANGE window frame), which
      // would break the CDF
      .withColumn("fp", sum("fw").over(win3)).rdd
      .map { row =>
        val appid = row.getAs[String]("appid")
        val fp = row.getAs[Double]("fp")
        (appid, fp)
      }.collect()
  }.sortBy(_._2)
  // trainItems.reverse.take(10)
  val trainItemsSize = trainItems.length
  // Debug check: per-item positive vs. negative sample counts
  // targetData.
  //   withColumn("fullSampleItems", fetchFullSampleItemsUdf(trainItems, trainItemsSize, 5)(col("app_info"))).
  //   withColumn("fullSampleItems", explode(col("fullSampleItems"))).
  //   withColumn("item_id", col("fullSampleItems").getField("_1")).
  //   withColumn("target", col("fullSampleItems").getField("_2")).
  //   groupBy("item_id").agg(expr("count(if(target == '1', user_id, null)) as pos_pv"),
  //     expr("count(if(target == '0', user_id, null)) as neg_pv")).orderBy(desc("pos_pv")).
  //   show(10, false)
  val userFeatures = fetchUserFeatures(spark, day, userMap)
  val itemFeatures = fetchItemFeatures(spark, day)
  val sampleData = {
    targetData.join(userFeatures, Seq("user_id"), "left")
      .withColumn("fullSampleItems", fetchFullSampleItemsUdf(trainItems, trainItemsSize, negNum)(col("app_info")))
      .withColumn("fullSampleItems", explode(col("fullSampleItems")))
      .withColumn("item_id", col("fullSampleItems").getField("_1"))
      .withColumn("target", col("fullSampleItems").getField("_2"))
      .join(broadcast(itemFeatures), Seq("item_id"), "left")
  }
  sampleData
}
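// Minimal driver sketch (hypothetical values; fetchTargetData / fetchUserFeatures /
// fetchItemFeatures are assumed to be defined elsewhere in this job):
// val spark = SparkSession.builder().appName("weighted-negative-sampling").getOrCreate()
// val samples = fetchSampleData(spark, day = "20230101", part = "0", negNum = 5)
// samples.groupBy("target").count().show()  // expect roughly negNum negatives per positive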