Spark: Finding the Top N Most Popular Courses

Database utility class

package com.rz.mobile_tag.utils

import java.sql.{Connection, DriverManager, PreparedStatement}

object MySQLUtils {
  /**
    * Obtain a database connection
    */
  def getConnection(): Connection = {
    DriverManager.getConnection("jdbc:mysql://localhost:3306/bigdata?user=root&password=root")
  }

  /**
    * Release the database connection and related resources
    * @param conn  the connection to close
    * @param pstmt the prepared statement to close
    */
  def release(conn: Connection, pstmt: PreparedStatement): Unit = {
    try {
      if (pstmt != null) {
        pstmt.close()
      }
    } catch {
      case e: Exception => e.printStackTrace()
    } finally {
      if (conn != null) {
        conn.close()
      }
    }
  }
}
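
The post assumes the target table already exists. For reference, here is a minimal sketch that creates it through MySQLUtils; the column layout is inferred from the INSERT statement in StatDao below, and the exact column types and primary key are assumptions:

import java.sql.{Connection, PreparedStatement}

import com.rz.mobile_tag.utils.MySQLUtils

object CreateStatTable {
  def main(args: Array[String]): Unit = {
    val conn: Connection = MySQLUtils.getConnection()
    // Schema is an assumption inferred from StatDao's insert statement
    val pstmt: PreparedStatement = conn.prepareStatement(
      """create table if not exists day_video_access_topn_stat (
        |  day varchar(8) not null,
        |  cms_id bigint not null,
        |  times bigint not null,
        |  primary key (day, cms_id)
        |)""".stripMargin)
    try {
      pstmt.execute()
    } finally {
      MySQLUtils.release(conn, pstmt)
    }
  }
}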

Data access class (optimization: write to the database with batched inserts, submitted as a single JDBC batch and committed once)

package com.rz.mobile_tag.dao

import java.sql.{Connection, PreparedStatement}

import com.rz.mobile_tag.bean.DayVideoAccessStat
import com.rz.mobile_tag.utils.MySQLUtils

import scala.collection.mutable.ListBuffer

object StatDao {
  /**
    * Batch-save a list of DayVideoAccessStat records to the database
    * @param list the records to insert
    */
  def insertDayVideoAccessTopN(list: ListBuffer[DayVideoAccessStat]): Unit = {
    var connection: Connection = null
    var pstmt: PreparedStatement = null

    try {
      connection = MySQLUtils.getConnection()

      connection.setAutoCommit(false) // switch to manual commit

      val sql = "insert into day_video_access_topn_stat(day, cms_id, times) values (?, ?, ?)"
      pstmt = connection.prepareStatement(sql)

      for (ele <- list) {
        pstmt.setString(1, ele.day)
        pstmt.setLong(2, ele.cmsId)
        pstmt.setLong(3, ele.times)
        pstmt.addBatch()
      }
      pstmt.executeBatch() // execute the whole batch at once
      connection.commit()  // commit manually
    } catch {
      case e: Exception => e.printStackTrace()
    } finally {
      MySQLUtils.release(connection, pstmt)
    }
  }
}
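
The DayVideoAccessStat bean imported above is not shown in the original post; a minimal sketch with exactly the three fields the DAO uses:

package com.rz.mobile_tag.bean

// One row of day_video_access_topn_stat: the day, the course id, and its access count
case class DayVideoAccessStat(day: String, cmsId: Long, times: Long)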


Business logic class

package com.rz.mobile_tag.log

import com.rz.mobile_tag.bean.DayVideoAccessStat
import com.rz.mobile_tag.dao.StatDao
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._

import scala.collection.mutable.ListBuffer

object TopNStatJob {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName(s"${this.getClass.getSimpleName}")
      // keep partition columns (e.g. day) as strings instead of inferring numeric types
      .config("spark.sql.sources.partitionColumnTypeInference.enabled", "false")
      .master("local[2]")
      .getOrCreate()

    val accessDF: DataFrame = spark.read.format("parquet").load(args(0))
    accessDF.printSchema()
    accessDF.show(false)

    // Top N most popular courses
    videoAccessTopNStat(spark, accessDF)

    spark.stop()
  }

  /**
    * Top N most popular courses
    * @param spark    the active SparkSession
    * @param accessDF the access-log DataFrame
    */
  def videoAccessTopNStat(spark: SparkSession, accessDF: DataFrame): Unit = {
    // Equivalent DataFrame API version (note: the alias must be applied to the
    // aggregated column, i.e. count("cmsId").as("times")):
    // import spark.implicits._
    // val videoAccessTopNDF: Dataset[Row] = accessDF.filter($"day" === "20190506" && $"cmsType" === "video")
    //   .groupBy("day", "cmsId")
    //   .agg(count("cmsId").as("times")).orderBy($"times".desc)
    // videoAccessTopNDF.show(false)

    accessDF.createOrReplaceTempView("access_logs")

    // Run the aggregation with SQL
    val videoAccessTopNDF: DataFrame = spark.sql("select day, cmsId, count(1) as times from access_logs" +
      " where day = '20190506' and cmsType = 'video' group by day, cmsId" +
      " order by times desc")

    // videoAccessTopNDF.show(false)

    // Write the statistics to MySQL, one batched insert per partition
    try {
      videoAccessTopNDF.foreachPartition(partitionOfRecords => {
        val list = new ListBuffer[DayVideoAccessStat]
        partitionOfRecords.foreach(info => {
          val day = info.getAs[String]("day")
          val cmsId = info.getAs[Long]("cmsId")
          val times = info.getAs[Long]("times")

          list.append(DayVideoAccessStat(day, cmsId, times))
        })
        StatDao.insertDayVideoAccessTopN(list)
      })
    } catch {
      case e: Exception => e.printStackTrace()
    }
  }
}
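
To sanity-check the job output, the table can be read back through the same MySQLUtils; a minimal sketch (the ordering matches the job, and the LIMIT of 10 is illustrative):

import java.sql.Connection

import com.rz.mobile_tag.utils.MySQLUtils

object TopNCheck {
  def main(args: Array[String]): Unit = {
    val conn: Connection = MySQLUtils.getConnection()
    val pstmt = conn.prepareStatement(
      "select day, cms_id, times from day_video_access_topn_stat order by times desc limit 10")
    try {
      val rs = pstmt.executeQuery()
      while (rs.next()) {
        println(s"${rs.getString("day")}\t${rs.getLong("cms_id")}\t${rs.getLong("times")}")
      }
    } finally {
      MySQLUtils.release(conn, pstmt)
    }
  }
}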
