import random
from pyspark.sql import SparkSession
spark = SparkSession.builder.enableHiveSupport().getOrCreate()
spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
spark.conf.set("spark.executor.memory", "10g")
sc = spark.sparkContext
sql = spark.sql
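# data_path and pb_output_path are assumed to be defined elsewhere; hypothetical examples:
# data_path = 'hdfs://namenode/path/to/pb_instances'
# pb_output_path = 'hdfs://namenode/path/to/pb_output'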
# Read protobuf records from HDFS
rdd = sc.hadoopFile(
    data_path,
    inputFormatClass='com.bytedance.hadoop.mapred.PBInputFormat',
    keyClass='org.apache.hadoop.io.BytesWritable',
    valueClass='org.apache.hadoop.io.BytesWritable'
)
def shuffle(instances):
    for instance in instances:
        # Generate a random shuffle key for each instance so that sortByKey()
        # redistributes records evenly across partitions
        yield random.randint(0, 100000), instance

def serialize(line):
    _, instance = line
    # `instance` is assumed to be a parsed protobuf message exposing a line_id field
    uid = instance.line_id.uid
    gid = instance.line_id.gid
    sort_id = (str(uid) + '#' + str(gid)).encode()
    data = instance.SerializeToString()
    return sort_id, data
# Shuffle; the transformation must be assigned, otherwise its result is discarded
shuffled_rdd = rdd.mapPartitions(shuffle).sortByKey()
# Write back to HDFS
shuffled_rdd.map(serialize).saveAsHadoopFile(
    pb_output_path,
    outputFormatClass='com.bytedance.hadoop.mapred.PBOutputFormat',
    keyClass='org.apache.hadoop.io.BytesWritable',
    valueClass='org.apache.hadoop.io.BytesWritable')
# Read from a Hive table
source_df = sql(READ_SQL)
rdd = source_df.rdd
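# output_rdd1 is assumed to be produced by further processing of `rdd`; a minimal sketch,
# assuming each element should be a tuple matching the `columns` schema below:
# output_rdd1 = rdd.map(lambda row: (row.uid, row.tag, row.c3_300_labels, row.embedding))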
# Write to a Hive table
columns = ['uid', 'tag', 'c3_300_labels', 'embedding']
df = output_rdd1.toDF(columns, sampleRatio=0.01)
df.createOrReplaceTempView("tmpv")
sql(WRITE_SQL)
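# READ_SQL and WRITE_SQL are assumed to be defined elsewhere; hypothetical examples, where
# WRITE_SQL reads from the temp view registered above and relies on the nonstrict
# dynamic-partition setting configured at the top of the script:
# READ_SQL = "SELECT uid, tag, c3_300_labels, embedding FROM db.source_table WHERE date = '20240101'"
# WRITE_SQL = """
#     INSERT OVERWRITE TABLE db.target_table PARTITION (date)
#     SELECT uid, tag, c3_300_labels, embedding, '20240101' AS date FROM tmpv
# """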