A utility class for writing Spark data to Phoenix with JDBC batch commits

I. Requirement: write data from Spark into Phoenix

II. Implementation

1. The official connector approach

    dataFrame.write
      .format("org.apache.phoenix.spark")
      .mode("overwrite")
      .option("table", table)
      .option("zkUrl", zkUrl)
      .option("skipNormalizingIdentifier", true)
      .save()

Under the hood this goes through PhoenixRecordWriter, the MapReduce RecordWriter implementation, which writes the rows over JDBC.

However, the default batch size is only 1000, so inserts are very slow, and the official docs do not describe the relevant write option; you have to dig through the source code to find it.

The write can be sped up by setting the batch-size option:

  .option(PhoenixConfigurationUtil.UPSERT_BATCH_SIZE,batch)
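
Putting the two snippets together, a minimal sketch of the official write path with a larger batch size might look like the following; table, zkUrl and batch are values you supply, and PhoenixConfigurationUtil is assumed to come from org.apache.phoenix.mapreduce.util:

    import org.apache.phoenix.mapreduce.util.PhoenixConfigurationUtil

    // sketch: assumes the phoenix-spark connector is on the classpath
    dataFrame.write
      .format("org.apache.phoenix.spark")
      .mode("overwrite")
      .option("table", table)
      .option("zkUrl", zkUrl)
      .option("skipNormalizingIdentifier", true)
      .option(PhoenixConfigurationUtil.UPSERT_BATCH_SIZE, batch.toString)
      .save()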

2. A hand-rolled generic JDBC approach (works for any JDBC-accessible store)

Code:

import java.sql.{DriverManager, PreparedStatement}
import java.util.Properties

import org.apache.spark.TaskContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

object JdbcUtils {
  /** Write a DataFrame to a JDBC table with batched UPSERTs and manual commits. */
  def jdbcBatchInsert(dataFrame: DataFrame, table: String, url: String, pro: Properties, batch: Int): Unit = {
    val fields: Array[String] = dataFrame.schema.fieldNames
    val schema: Array[StructField] = dataFrame.schema.toArray
    val numFields = fields.length
    val fieldsSql = fields.map(str => "\"".concat(str).concat("\"")).mkString("(", ",", ")")
    val charSql = fields.map(str => "?").mkString(",")
    val setters: Array[JDBCValueSetter] = schema.map(f => makeSetter(f.dataType))
    val insertSql = s"upsert into $table $fieldsSql values ($charSql) "
    System.err.println("插入sql:" + insertSql)
    val start = System.currentTimeMillis()
    dataFrame.rdd.foreachPartition(partition => {
      val connection = DriverManager.getConnection(url, pro)
      try {
        connection.setAutoCommit(false)
        val pstmt: PreparedStatement = connection.prepareStatement(insertSql)
        var count = 0
        var cnt = 0
        partition.foreach(row => {
          for (i <- 0 until numFields) {
            if (row.isNullAt(i)) {
              pstmt.setNull(i + 1, getJdbcType(schema(i).dataType))
            } else {
              setters(i).apply(pstmt, row, i)
            }
          }
          pstmt.addBatch()
          count += 1
          if (count % batch == 0) {
            pstmt.executeBatch()
            connection.commit()
            cnt += 1
            println(s"${TaskContext.get.partitionId}分区,提交第${cnt}次,${count}tiao")
          }
        })
        pstmt.executeBatch()
        connection.commit()
        println(s"第${TaskContext.get.partitionId}分区,共提交第${cnt},${count}条")
      } finally {
        connection.close()
      }
    })
    val end = System.currentTimeMillis()
    println(s"插入表$table,共花费时间${(end - start) / 1000}秒")
  }


  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  /**
   * Match a Spark SQL DataType to a PreparedStatement setter; add more cases if other types are needed.
   *
   * @param dataType Spark SQL column type
   * @return a setter that binds the value at the given row position
   */
  def makeSetter(dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        if (row.isNullAt(pos)) {
          stmt.setNull(pos + 1, java.sql.Types.INTEGER)
        } else {
          stmt.setInt(pos + 1, row.getInt(pos))
        }
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setString(pos + 1, row.getString(pos))

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    /* case ArrayType(et, _) =>
       // remove type length parameters from end of type name
       val typeName = getJdbcType(et, dialect).databaseTypeDefinition
         .toLowerCase(Locale.ROOT).split("\\(")(0)
       (stmt: PreparedStatement, row: Row, pos: Int) =>
         val array = conn.createArrayOf(
           typeName,
           row.getSeq[AnyRef](pos).toArray)
         stmt.setArray(pos + 1, array)*/

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
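
  // Sketch of how the commented-out ArrayType case above could be supported: the JDBC
  // Connection would have to be passed in so createArrayOf can be used. elementTypeName
  // (e.g. "VARCHAR" or "INTEGER") is assumed to match the array's element type. Untested.
  def makeArraySetter(conn: java.sql.Connection, elementTypeName: String): JDBCValueSetter =
    (stmt: PreparedStatement, row: Row, pos: Int) => {
      // build a java.sql.Array from the Seq stored in the row, then bind it
      val array = conn.createArrayOf(elementTypeName, row.getSeq[AnyRef](pos).toArray)
      stmt.setArray(pos + 1, array)
    }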

  /**
   * Match a Spark SQL DataType to a java.sql.Types code (used for setNull); add more cases if other types are needed.
   *
   * @param dt Spark SQL column type
   * @return the corresponding java.sql.Types constant
   */
  private def getJdbcType(dt: DataType): Int = {
    dt match {
      case IntegerType => java.sql.Types.INTEGER
      case LongType => java.sql.Types.BIGINT
      case DoubleType => java.sql.Types.DOUBLE
      case StringType => java.sql.Types.VARCHAR
      case _ => java.sql.Types.VARCHAR
    }
  }
}  

Usage:

import java.util.Properties
import org.apache.phoenix.query.QueryServices

// config is a Map of settings; it is optional, the defaults below are used when a key is missing
val connectionProperties = new Properties()
connectionProperties.setProperty(QueryServices.MAX_MUTATION_SIZE_ATTRIB,
  config.getOrDefault("phoenix.mutate.maxSize", "500000")) // change the default limit of 500000 mutations
connectionProperties.setProperty(QueryServices.MUTATE_BATCH_SIZE_BYTES_ATTRIB,
  config.getOrDefault("phoenix.mutate.batchSizeBytes", "1073741824000"))

 val batch = config.getOrDefault("phoenix.insert.batchSize", "50000").toInt

// call the batch-insert helper
 JdbcUtils.jdbcBatchInsert(dataFrame, table, phoenixUrl, connectionProperties, batch);
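
For reference, phoenixUrl above is assumed to be an ordinary Phoenix JDBC URL pointing at the ZooKeeper quorum; the hosts and port below are placeholders:

    val phoenixUrl = "jdbc:phoenix:zk1,zk2,zk3:2181"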

  

 
