PySpark study notes

Spark distributed storage

# Don't change this query
query = "FROM flights SELECT * LIMIT 10"

# Get the first 10 rows of flights
flights10 = spark.sql(query)

# Show the results
flights10.show()

Pandafy a Spark DataFrame

Display a Spark DataFrame in pandas form. toPandas() collects the whole result onto the driver, so it is best used on small results such as this aggregate.

# Don't change this query
query = "SELECT origin, dest, COUNT(*) as N FROM flights GROUP BY origin, dest"

# Run the query
flight_counts = spark.sql(query)

# Convert the results to a pandas DataFrame
pd_counts = flight_counts.toPandas()

# Print the head of pd_counts
print(pd_counts.head())

<script.py> output:
      origin  ...    N
    0    SEA  ...    8
    1    SEA  ...   98
    2    SEA  ...    2
    3    SEA  ...  450
    4    PDX  ...  144
    
    [5 rows x 3 columns]

Reading files

# Don't change this file path
file_path = "/usr/local/share/datasets/airports.csv"

# Read in the airports data
airports = spark.read.csv(file_path, header=True)

# Show the data
airports.show()

Use the spark.table() method with the argument "flights" to create a DataFrame containing the values of the flights table in the .catalog. Save it as flights.
Show the head of flights using flights.show(). The column air_time contains the duration of the flight in minutes.
Update flights to include a new column called duration_hrs, that contains the duration of each flight in hours.

All of the operations below are performed on DataFrames.


# Create the DataFrame flights
flights = spark.table("flights")

# Show the head
flights.show()

# Add duration_hrs
flights = flights.withColumn("duration_hrs", flights.air_time/60)

Filtering Data

# Filter flights by passing a string
long_flights1 = flights.filter("distance > 1000")

# Filter flights by passing a column of boolean values
long_flights2 = flights.filter(flights.distance > 1000)

# Print the data to check they're equal
long_flights1.show()
long_flights2.show()

.show() displays the data.

.filter() keeps the rows that satisfy the given condition, much like filtering in pandas.

# Select the first set of columns
selected1 = flights.select("tailnum", "origin", "dest")

# Select the second set of columns
temp = flights.select(flights.origin, flights.dest, flights.carrier)

# This way of selecting columns by name is very similar to R
# Define first filter
filterA = flights.origin == "SEA"

# Define second filter
filterB = flights.dest == "PDX"

# Filter the data, first by filterA then by filterB
selected2 = temp.filter(filterA).filter(filterB)

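alias()

selectExpr

Both can rename columns.

selectExpr is essentially the same as calling select with the expr function: both build complex expressions from SQL strings. A few examples follow.

df.selectExpr("appid as newappid").show()

The line above selects the appid column and renames it to newappid.

df.selectExpr("count(distinct(appid)) as count1", "count(distinct(brand)) as count2").show()

The line above counts the distinct values of appid and the distinct values of brand.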
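
To make the second expression concrete, here is a plain-Python sketch of what count(distinct(...)) computes, with no Spark involved. The rows and their values are invented for illustration, and the helper name count_distinct is hypothetical. The one SQL behaviour it assumes is that COUNT(DISTINCT col) skips NULLs, modelled here as None.

```python
# Hypothetical sample rows standing in for df; the values are invented.
rows = [
    {"appid": "a1", "brand": "x"},
    {"appid": "a1", "brand": "y"},
    {"appid": "a2", "brand": None},
    {"appid": "a3", "brand": "x"},
]


def count_distinct(rows, column):
    """Count distinct non-null values in one column, like COUNT(DISTINCT column)."""
    return len({row[column] for row in rows if row[column] is not None})


count1 = count_distinct(rows, "appid")
count2 = count_distinct(rows, "brand")
print(count1, count2)
```

Each expression collapses the whole DataFrame to a single value, so the selectExpr call returns one row with two columns.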

Aggregate functions

# Find the shortest flight from PDX in terms of distance
flights.filter(flights.origin == "PDX").groupBy().min("distance").show()

# Find the longest flight from SEA in terms of air time
flights.filter(flights.origin == "SEA").groupBy().max("air_time").show()


<script.py> output:
    +-------------+
    |min(distance)|
    +-------------+
    |          106|
    +-------------+
    
    +-------------+
    |max(air_time)|
    +-------------+
    |          409|
    +-------------+

# Average duration of Delta flights
flights.filter(flights.carrier == "DL").filter(flights.origin == "SEA").groupBy().avg("air_time").show()

# Total hours in the air
flights.withColumn("duration_hrs", flights.air_time/60).groupBy().sum("duration_hrs").show()

withColumn adds a new column to the DataFrame.

Create a DataFrame called by_plane that is grouped by the column tailnum.
Use the .count() method with no arguments to count the number of flights each plane made.
Create a DataFrame called by_origin that is grouped by the column origin.
Find the .avg() of the air_time column to find average duration of flights from PDX and SEA.

# Group by tailnum
by_plane = flights.groupBy("tailnum")

# Number of flights each plane made
by_plane.count().show()

# Group by origin
by_origin = flights.groupBy("origin")

# Average duration of flights from PDX and SEA
by_origin.avg("air_time").show()

Some aggregate functions, such as stddev below, have to be imported from pyspark.sql.functions.

# Import pyspark.sql.functions as F
import pyspark.sql.functions as F

# Group by month and dest
by_month_dest = flights.groupBy("month", "dest")

# Average departure delay by month and destination
by_month_dest.avg("dep_delay").show()

# Standard deviation of departure delay
by_month_dest.agg(F.stddev("dep_delay")).show()

    +-----+----+----------------------+
    |month|dest|stddev_samp(dep_delay)|
    +-----+----+----------------------+
    |   11| TUS|    3.0550504633038935|
    |   11| ANC|    18.604716401245316|
    |    1| BUR|     15.22627576540667|
    |    1| PDX|     5.677214918493858|
    |    6| SBA|     2.380476142847617|
    |    5| LAX|     13.36268698685904|
    |   10| DTW|     5.639148871948674|
    |    6| SIT|                   NaN|
    |   10| DFW|     45.53019017606675|
    |    3| FAI|    3.1144823004794873|
    |   10| SEA|     18.70523227029577|
    |    2| TUS|    14.468356276140469|
    |   12| OGG|     82.64480404939947|
    |    9| DFW|    21.728629347782924|
    |    5| EWR|     42.41595968929191|
    |    3| RDM|      2.16794833886788|
    |    8| DCA|     9.946523680831074|
    |    7| ATL|    22.767001039582183|
    |    4| JFK|     8.156774303176903|
    |   10| SNA|    13.726234873756304|
    +-----+----+----------------------+
    only showing top 20 rows
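
One plausible reading of the NaN in the (6, SIT) row, which the output itself does not confirm, is that this group contains a single flight. A sample standard deviation divides by n - 1, so it is undefined for one value. The sketch below is plain Python rather than Spark code, and the function name simply mirrors the column header; depending on the Spark version, the single-row case may show up as NaN or as null.

```python
import math


def stddev_samp(values):
    """Sample standard deviation; NaN when fewer than two values are available."""
    n = len(values)
    if n < 2:
        # The n - 1 denominator would be zero, so the result is undefined.
        return float("nan")
    mean = sum(values) / n
    squared_diffs = sum((v - mean) ** 2 for v in values)
    return math.sqrt(squared_diffs / (n - 1))


single = stddev_samp([12.0])            # hypothetical group with one flight
several = stddev_samp([2.0, 4.0, 6.0])  # hypothetical group with three flights
print(single, several)
```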

join: joining tables

# Examine the data
airports.show()

# Rename the faa column
airports = airports.withColumnRenamed("faa", "dest")

# Join the DataFrames
flights_with_airports = flights.join(airports, on="dest", how="leftouter")

# Examine the new DataFrame
flights_with_airports.show()

Machine Learning Pipelines

Joining DataFrames

# Rename year column
planes = planes.withColumnRenamed("year", "plane_year")

# Join the DataFrames
model_data = flights.join(planes, on="tailnum", how="leftouter")

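cast

cast converts the data type of a column.

# col must be imported before it can be used
from pyspark.sql.functions import col

result = table1.join(table1, ["field"], "full").withColumn("name", col("field") / col("field"))

This adds a new column whose values are col("field") / col("field").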

# To convert the type of a column using the .cast() method, you can write code like this:
dataframe = dataframe.withColumn("col", dataframe.col.cast("new_type"))


# Cast the columns to integers
model_data = model_data.withColumn("arr_delay", model_data.arr_delay.cast("integer"))
model_data = model_data.withColumn("air_time", model_data.air_time.cast("integer"))
model_data = model_data.withColumn("month", model_data.month.cast("integer"))
model_data = model_data.withColumn("plane_year", model_data.plane_year.cast("integer"))
# Create is_late
model_data = model_data.withColumn("is_late", model_data.arr_delay > 0)

# Convert to an integer
model_data = model_data.withColumn("label", model_data.is_late.cast("integer"))

# Remove missing values
model_data = model_data.filter("arr_delay is not NULL and dep_delay is not NULL and air_time is not NULL and plane_year is not NULL")
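
The label logic above can be sketched in plain Python to show why the null filter matters. This is an illustration of the semantics, not Spark code: the helper names and the sample delays are invented, and None stands in for SQL NULL, which propagates through both the comparison and the cast.

```python
def is_late(arr_delay):
    """Mirror arr_delay > 0, propagating None the way SQL propagates NULL."""
    return None if arr_delay is None else arr_delay > 0


def to_label(flag):
    """Mirror cast("integer") on a boolean: True -> 1, False -> 0, None stays None."""
    return None if flag is None else int(flag)


delays = [15, -3, 0, None]  # hypothetical arr_delay values
labels = [to_label(is_late(d)) for d in delays]
print(labels)
```

A flight that arrives exactly on time (delay 0) is labelled 0, and a missing delay yields a missing label, which is why rows with nulls are removed before modelling.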
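
# Import the feature transformers used below
from pyspark.ml.feature import StringIndexer, OneHotEncoder

# Create a StringIndexer (maps each string category to a numeric index)
carr_indexer = StringIndexer(inputCol="carrier", outputCol="carrier_index")

# Create a OneHotEncoder (one-hot encodes the numeric index)
carr_encoder = OneHotEncoder(inputCol="carrier_index", outputCol="carrier_fact")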

pipeline

# Import Pipeline
from pyspark.ml import Pipeline

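# Make the pipeline
# (dest_indexer, dest_encoder and vec_assembler are not defined in these notes;
# dest_indexer and dest_encoder are assumed to be built like carr_indexer and
# carr_encoder above, and vec_assembler is assumed to be a VectorAssembler)
flights_pipe = Pipeline(stages=[dest_indexer, dest_encoder, carr_indexer, carr_encoder, vec_assembler])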

fit_transform

# Fit and transform the data
piped_data = flights_pipe.fit(model_data).transform(model_data)

Splitting the dataset

# Split the data into training and test sets (similar to scikit-learn's train_test_split)
training, test = piped_data.randomSplit([.6, .4])

Logistic regression

# Import LogisticRegression
from pyspark.ml.classification import LogisticRegression

# Create a LogisticRegression Estimator
lr = LogisticRegression()

Evaluation metric

# Import the evaluation submodule
import pyspark.ml.evaluation as evals

# Create a BinaryClassificationEvaluator
evaluator = evals.BinaryClassificationEvaluator(metricName="areaUnderROC")

Make a grid

Grid search

# Import the tuning submodule and NumPy (np.arange is used below)
import numpy as np
import pyspark.ml.tuning as tune

# Create the parameter grid
grid = tune.ParamGridBuilder()

# Add the hyperparameter
grid = grid.addGrid(lr.regParam, np.arange(0, .1, .01))
grid = grid.addGrid(lr.elasticNetParam, [0, 1])

# Build the grid
grid = grid.build()
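
To get a feel for how large this grid is, the following plain-Python sketch enumerates the same combinations with itertools instead of ParamGridBuilder. The list reg_params is a stand-in for np.arange(0, .1, .01), written out explicitly to avoid floating-point surprises; the variable names are assumptions of this sketch.

```python
from itertools import product

# Stand-ins for the values passed to addGrid above.
reg_params = [round(i * 0.01, 2) for i in range(10)]  # 0.0, 0.01, ..., 0.09
elastic_net_params = [0, 1]

# Every (regParam, elasticNetParam) pair becomes one candidate model.
combinations = list(product(reg_params, elastic_net_params))
print(len(combinations))
print(combinations[:3])
```

Each of these combinations is fitted once per fold during cross-validation, so the grid size multiplies the training cost directly.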

Cross-validation

# Create the CrossValidator
cv = tune.CrossValidator(estimator=lr,
                         estimatorParamMaps=grid,
                         evaluator=evaluator
                         )

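Model evaluation

# Use the model to predict the test set
# (best_lr is not defined in these notes; it is assumed to be a fitted model,
# for example the result of lr.fit(training))
test_results = best_lr.transform(test)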

# Evaluate the predictions
print(evaluator.evaluate(test_results))

drop

# Load the CSV file
aa_dfw_df = spark.read.format('csv').options(Header=True).load('AA_DFW_2018.csv.gz')

# Add the airport column using the F.lower() method
aa_dfw_df = aa_dfw_df.withColumn('airport', F.lower(aa_dfw_df['Destination Airport']))

# Drop the Destination Airport column
aa_dfw_df = aa_dfw_df.drop(aa_dfw_df['Destination Airport'])

# Show the DataFrame
aa_dfw_df.show()


<script.py> output:
    +-----------------+-------------+-----------------------------+-------+
    |Date (MM/DD/YYYY)|Flight Number|Actual elapsed time (Minutes)|airport|
    +-----------------+-------------+-----------------------------+-------+
    |       01/01/2018|         0005|                          498|    hnl|
    |       01/01/2018|         0007|                          501|    ogg|
    |       01/01/2018|         0043|                            0|    dtw|
    |       01/01/2018|         0051|                          100|    stl|
    |       01/01/2018|         0075|                          147|    dca|
    |       01/01/2018|         0096|                           92|    stl|
    |       01/01/2018|         0103|                          227|    sjc|
    |       01/01/2018|         0119|                          517|    ogg|
    |       01/01/2018|         0123|                          489|    hnl|
    |       01/01/2018|         0128|                          141|    mco|
    |       01/01/2018|         0132|                          201|    ewr|
    |       01/01/2018|         0140|                          215|    sjc|
    |       01/01/2018|         0174|                          140|    rdu|
    |       01/01/2018|         0190|                           68|    sat|
    |       01/01/2018|         0200|                          215|    sfo|
    |       01/01/2018|         0209|                          169|    mia|
    |       01/01/2018|         0217|                          178|    las|
    |       01/01/2018|         0229|                          534|    koa|
    |       01/01/2018|         0244|                          115|    cvg|
    |       01/01/2018|         0262|                          159|    mia|
    +-----------------+-------------+-----------------------------+-------+
    only showing top 20 rows

Saving a DataFrame in Parquet format

Parquet is a columnar storage format that supports compression.

# Combine the DataFrames into one 
df3 = df1.union(df2)

# Save the df3 DataFrame in Parquet format
df3.write.parquet('AA_DFW_ALL.parquet', mode='overwrite')

# Read the Parquet file into a new DataFrame and run a count
print(spark.read.parquet('AA_DFW_ALL.parquet').count())

<script.py> output:
    df1 Count: 139359
    df2 Count: 119911
    259270

createOrReplaceTempView

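createOrReplaceTempView creates a temporary view. The view disappears as soon as the session that created it is closed, and other SparkSessions cannot see it.

createGlobalTempView creates a global temporary view whose lifetime is the whole Spark application. As long as the application keeps running the view remains usable, and it is shared with other SparkSessions (it is queried through the global_temp database).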

filter

Show the distinct VOTER_NAME entries

voter_df.select(voter_df['VOTER_NAME']).distinct().show(40, truncate=False)

Filter voter_df where the VOTER_NAME is 1-19 characters in length (the condition below is length > 0 and length < 20)

voter_df = voter_df.filter('length(VOTER_NAME) > 0 and length(VOTER_NAME) < 20')

Filter out voter_df where the VOTER_NAME contains an underscore

voter_df = voter_df.filter(~ F.col('VOTER_NAME').contains('_'))

Show the distinct VOTER_NAME entries again

voter_df.select('VOTER_NAME').distinct().show(40, truncate=False)

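DataFrame column operations

.split(),
.size(),
.getItem().

For the individual built-in functions, see the pyspark.sql.functions documentation.

withColumn

Similar to mutate in R.

It adds a new column: the first argument is the column name and the second is the expression that produces the column's values.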

# when is used unqualified below, so import it explicitly
from pyspark.sql.functions import when

# Add a column to voter_df for any voter with the title **Councilmember**
voter_df = voter_df.withColumn('random_val',
                               when(voter_df.TITLE == 'Councilmember', F.rand()))
# Show some of the DataFrame rows, noting whether the when clause worked
voter_df.show()

<script.py> output:
    +----------+-------------+-------------------+-------------------+
    |      DATE|        TITLE|         VOTER_NAME|         random_val|
    +----------+-------------+-------------------+-------------------+
    |02/08/2017|Councilmember|  Jennifer S. Gates|  0.272243504035671|
    |02/08/2017|Councilmember| Philip T. Kingston|  0.583853158237357|
    |02/08/2017|        Mayor|Michael S. Rawlings|               null|
    |02/08/2017|Councilmember|       Adam Medrano|0.14137215411622484|
    |02/08/2017|Councilmember|       Casey Thomas| 0.9411379198657572|
    |02/08/2017|Councilmember|Carolyn King Arnold| 0.8379601212162058|
    |02/08/2017|Councilmember|       Scott Griggs|0.18946575456658876|
    |02/08/2017|Councilmember|   B. Adam  McGough| 0.4952465558048347|
    |02/08/2017|Councilmember|       Lee Kleinman| 0.8047711324429991|
    |02/08/2017|Councilmember|      Sandy Greyson|  0.719737910435363|
    |02/08/2017|Councilmember|  Jennifer S. Gates| 0.7558084225784659|
    |02/08/2017|Councilmember| Philip T. Kingston| 0.7274572656632454|
    |02/08/2017|        Mayor|Michael S. Rawlings|               null|
    |02/08/2017|Councilmember|       Adam Medrano|   0.68576094369742|
    |02/08/2017|Councilmember|       Casey Thomas|0.32803656527818037|
    |02/08/2017|Councilmember|Carolyn King Arnold|0.24756136511724325|
    |02/08/2017|Councilmember| Rickey D. Callahan| 0.8197696131561757|
    |01/11/2017|Councilmember|  Jennifer S. Gates| 0.2346485886453783|
    |04/25/2018|Councilmember|     Sandy  Greyson| 0.4073895538345208|
    |04/25/2018|Councilmember| Jennifer S.  Gates| 0.1938743267054036|
    +----------+-------------+-------------------+-------------------+
    only showing top 20 rows
    

when/otherwise

# Add a column to voter_df for a voter based on their position
voter_df = voter_df.withColumn('random_val',
                               when(voter_df.TITLE == 'Councilmember', F.rand())
                               .when(voter_df.TITLE == 'Mayor', 2)
                               .otherwise(0))
# Show some of the DataFrame rows
voter_df.show()

# Use the .filter() clause with random_val
voter_df.filter(voter_df.random_val == 0).show()

when ... otherwise works like an if / else chain that is evaluated for every row.
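
As a rough plain-Python equivalent of the chain above, applied to one row at a time: random.random() stands in for F.rand(), and the sample titles are invented. This is only a sketch of the semantics, not how Spark evaluates the expression.

```python
import random


def random_val(title):
    """Row-wise equivalent of the when(...).when(...).otherwise(...) chain above."""
    if title == "Councilmember":
        return random.random()  # stands in for F.rand(): uniform in [0, 1)
    elif title == "Mayor":
        return 2
    else:
        return 0  # the otherwise() branch


titles = ["Councilmember", "Mayor", "Mayor Pro Tem"]  # hypothetical TITLE values
values = [random_val(t) for t in titles]
print(values)
```

Without the otherwise() branch, rows that match no condition get null, which is what the Mayor rows show in the earlier output.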

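User-defined functions

Besides calling built-in functions you can define your own. This is the same in any language: a function simply implements some piece of functionality.

# StringType is needed for the UDF's return type
from pyspark.sql.types import StringType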

def getFirstAndMiddle(names):
  # Return a space separated string of names
  return ' '.join(names[:-1])

# Define the method as a UDF
udfFirstAndMiddle = F.udf(getFirstAndMiddle, StringType())

# Create a new column using your UDF
voter_df = voter_df.withColumn('first_and_middle_name', udfFirstAndMiddle(voter_df.splits))

# Show the DataFrame
voter_df.show()
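
Because the function wrapped by F.udf is ordinary Python, it can be checked on plain lists before it is run through Spark. The sketch below does that; the input lists are hypothetical examples shaped like the arrays assumed to be in voter_df.splits.

```python
def getFirstAndMiddle(names):
    # Return a space separated string of names, dropping the last element
    return ' '.join(names[:-1])


# Hypothetical inputs shaped like the arrays in voter_df.splits
full = getFirstAndMiddle(['Jennifer', 'S.', 'Gates'])
short = getFirstAndMiddle(['Adam', 'Medrano'])
single = getFirstAndMiddle(['Mayor'])
print(repr(full), repr(short), repr(single))
```

A one-element list yields an empty string, which is worth knowing before applying the UDF to names that did not split cleanly.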

Partitioning and lazy processing

# Select all the unique council voters
voter_df = df.select(df["VOTER NAME"]).distinct()  # distinct() removes duplicate rows

# Count the rows in voter_df
print("\nThere are %d rows in the voter_df DataFrame.\n" % voter_df.count())

# Add a ROW_ID
voter_df = voter_df.withColumn('ROW_ID', F.monotonically_increasing_id())

# Show the rows with 10 highest IDs in the set
voter_df.orderBy(voter_df.ROW_ID.desc()).show(10)


<script.py> output:
    
    There are 36 rows in the voter_df DataFrame.
    
    +--------------------+-------------+
    |          VOTER NAME|       ROW_ID|
    +--------------------+-------------+
    |        Lee Kleinman|1709396983808|
    |  the  final  201...|1700807049217|
    |         Erik Wilson|1700807049216|
    |  the  final   20...|1683627180032|
    | Carolyn King Arnold|1632087572480|
    | Rickey D.  Callahan|1597727834112|
    |   the   final  2...|1443109011456|
    |    Monica R. Alonzo|1382979469312|
    |     Lee M. Kleinman|1228360646656|
    |   Jennifer S. Gates|1194000908288|
    +--------------------+-------------+
    only showing top 10 rows

Showing the number of partitions

# Print the number of partitions in each DataFrame
print("\nThere are %d partitions in the voter_df DataFrame.\n" % voter_df.rdd.getNumPartitions())  # getNumPartitions() returns the partition count
print("\nThere are %d partitions in the voter_df_single DataFrame.\n" % voter_df_single.rdd.getNumPartitions())

# Add a ROW_ID field to each DataFrame
voter_df = voter_df.withColumn('ROW_ID', F.monotonically_increasing_id())
voter_df_single = voter_df_single.withColumn('ROW_ID', F.monotonically_increasing_id())

# Show the top 10 IDs in each DataFrame 
voter_df.orderBy(voter_df.ROW_ID.desc()).show(10)
voter_df_single.orderBy(voter_df_single.ROW_ID.desc()).show(10)

<script.py> output:
    
    There are 200 partitions in the voter_df DataFrame.
    
    
    There are 1 partitions in the voter_df_single DataFrame.
    
    +--------------------+-------------+
    |          VOTER NAME|       ROW_ID|
    +--------------------+-------------+
    |        Lee Kleinman|1709396983808|
    |  the  final  201...|1700807049217|
    |         Erik Wilson|1700807049216|
    |  the  final   20...|1683627180032|
    | Carolyn King Arnold|1632087572480|
    | Rickey D.  Callahan|1597727834112|
    |   the   final  2...|1443109011456|
    |    Monica R. Alonzo|1382979469312|
    |     Lee M. Kleinman|1228360646656|
    |   Jennifer S. Gates|1194000908288|
    +--------------------+-------------+
    only showing top 10 rows
    
    +--------------------+------+
    |          VOTER NAME|ROW_ID|
    +--------------------+------+
    |        Lee Kleinman|    35|
    |  the  final  201...|    34|
    |         Erik Wilson|    33|
    |  the  final   20...|    32|
    | Carolyn King Arnold|    31|
    | Rickey D.  Callahan|    30|
    |   the   final  2...|    29|
    |    Monica R. Alonzo|    28|
    |     Lee M. Kleinman|    27|
    |   Jennifer S. Gates|    26|
    +--------------------+------+
    only showing top 10 rows
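
The large IDs in the 200-partition case are not random. The Spark documentation describes the current implementation of monotonically_increasing_id as putting the partition ID in the upper bits and the row number within the partition in the lower 33 bits. Assuming that layout, a short plain-Python sketch can decode the IDs shown above; decode_row_id is a hypothetical helper, not a Spark function.

```python
ROW_BITS = 33  # assumed layout: lower 33 bits hold the row number within a partition


def decode_row_id(row_id):
    """Split an ID into (partition index, row number within that partition)."""
    partition = row_id >> ROW_BITS
    row_in_partition = row_id & ((1 << ROW_BITS) - 1)
    return partition, row_in_partition


# IDs copied from the 200-partition output above, plus one from the single-partition output
decoded = [decode_row_id(i) for i in (1709396983808, 1700807049217, 1700807049216, 35)]
print(decoded)
```

The IDs are therefore unique and increasing but not consecutive unless all rows sit in a single partition, as in voter_df_single.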

.rdd

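The RDD is the most important abstraction Spark provides. It is a fault-tolerant collection of data that can be distributed across the nodes of a cluster and processed in parallel through functional-style collection operations.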

cache

Caching

import time

start_time = time.time()

# Add caching to the unique rows in departures_df
departures_df = departures_df.distinct().cache() 

# Count the unique rows in departures_df, noting how long the operation takes
print("Counting %d rows took %f seconds" % (departures_df.count(), time.time() - start_time))  # time the first count

# Count the rows again, noting the variance in time of a cached DataFrame
start_time = time.time()
print("Counting %d rows again took %f seconds" % (departures_df.count(), time.time() - start_time))

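is_cached shows whether a DataFrame is currently in the cache.

Comparing the two timings above shows how much the cache helps.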

# Determine if departures_df is in the cache
print("Is departures_df cached?: %s" % departures_df.is_cached)
print("Removing departures_df from cache")

# Remove departures_df from the cache
departures_df.unpersist()

# Check the cache status again
print("Is departures_df cached?: %s" % departures_df.is_cached)

read config

# Name of the Spark application instance
app_name = spark.conf.get('spark.app.name')

# Driver TCP port
driver_tcp_port = spark.conf.get('spark.driver.port')

# Number of join partitions
num_partitions = spark.conf.get('spark.sql.shuffle.partitions')

# Show the results
print("Name: %s" % app_name)
print("Driver TCP port: %s" % driver_tcp_port)
print("Number of partitions: %s" % num_partitions)

<script.py> output:
    Name: pyspark-shell
    Driver TCP port: 45583
    Number of partitions: 200

write config

# Store the number of partitions in variable
before = departures_df.rdd.getNumPartitions()

# Configure Spark to use 500 partitions
spark.conf.set('spark.sql.shuffle.partitions', 500)

# Recreate the DataFrame using the departures data file
departures_df = spark.read.csv('departures.txt.gz').distinct()

# Print the number of partitions for each instance
print("Partition count before change: %d" % before)
print("Partition count after change: %d" % departures_df.rdd.getNumPartitions())

join data

# Join the flights_df and airports_df DataFrames
normal_df = flights_df.join(airports_df,
    flights_df["Destination Airport"] == airports_df["IATA"])

# Show the query plan
normal_df.explain()

DataFrames come with a great many built-in functions; a CSDN reference lists them.
It is worth knowing the commonly used ones, since they save effort.

using broadcasting

# Import the broadcast method from pyspark.sql.functions
from pyspark.sql.functions import broadcast

# Join the flights_df and airports_df DataFrames using broadcasting
broadcast_df = flights_df.join(broadcast(airports_df),
    flights_df["Destination Airport"] == airports_df["IATA"])

# Show the query plan and compare against the original
broadcast_df.explain()


<script.py> output:
    == Physical Plan ==
    *(2) BroadcastHashJoin [Destination Airport#12], [IATA#29], Inner, BuildRight
    :- *(2) Project [Date (MM/DD/YYYY)#10, Flight Number#11, Destination Airport#12, Actual elapsed time (Minutes)#13]
    :  +- *(2) Filter isnotnull(Destination Airport#12)
    :     +- *(2) FileScan csv [Date (MM/DD/YYYY)#10,Flight Number#11,Destination Airport#12,Actual elapsed time (Minutes)#13] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/tmp/tmpvcaouj4c/AA_DFW_2018_Departures_Short.csv.gz], PartitionFilters: [], PushedFilters: [IsNotNull(Destination Airport)], ReadSchema: struct<Date (MM/DD/YYYY):string,Flight Number:string,Destination Airport:string,Actual elapsed ti...
    +- BroadcastExchange HashedRelationBroadcastMode(List(input[1, string, true]))
       +- *(1) Project [AIRPORTNAME#28, IATA#29]
          +- *(1) Filter isnotnull(IATA#29)
             +- *(1) FileScan csv [AIRPORTNAME#28,IATA#29] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/tmp/tmpvcaouj4c/airportnames.txt.gz], PartitionFilters: [], PushedFilters: [IsNotNull(IATA)], ReadSchema: struct<AIRPORTNAME:string,IATA:string>
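
The plan above shows a BroadcastHashJoin with the airports table as the build side. As a rough plain-Python sketch of that idea (not Spark's implementation): the small table is hashed once into a lookup that every worker can hold, and the large table is streamed past it without being shuffled. All rows below are invented sample data.

```python
# Hypothetical sample rows; the real tables are read from CSV files.
flights = [
    {"Flight Number": "0001", "Destination Airport": "HNL"},
    {"Flight Number": "0002", "Destination Airport": "OGG"},
    {"Flight Number": "0003", "Destination Airport": None},
    {"Flight Number": "0004", "Destination Airport": "XXX"},
]
airports = [
    {"AIRPORTNAME": "Honolulu", "IATA": "HNL"},
    {"AIRPORTNAME": "Kahului", "IATA": "OGG"},
]

# Build side: hash the small table once (null keys are dropped, as in the plan's filters).
lookup = {a["IATA"]: a for a in airports if a["IATA"] is not None}

# Probe side: stream the large table and keep only rows with a match (inner join).
joined = [
    {**f, **lookup[f["Destination Airport"]]}
    for f in flights
    if f["Destination Airport"] in lookup
]
print(joined)
```

This only pays off when the broadcast table is small enough to fit in memory on each executor, which is the case for a list of airport names.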

Processing data as a pipeline

# Import the data to a DataFrame
departures_df = spark.read.csv('2015-departures.csv.gz', header=True)

# Remove any duration of 0
departures_df = departures_df.filter(departures_df[3] > 0)

# Add an ID column
departures_df = departures_df.withColumn('id', F.monotonically_increasing_id())

# Write the file out to JSON format
departures_df.write.json('output.json', mode='overwrite')

## Some data-processing tips

# col is used below, so import it first
from pyspark.sql.functions import col

# Import the file to a DataFrame and perform a row count
annotations_df = spark.read.csv('annotations.csv.gz', sep='|')
full_count = annotations_df.count()

# Count the number of rows beginning with '#'
comment_count = annotations_df.where(col('_c0').startswith('#')).count()

# Import the file to a new DataFrame, without commented rows
no_comments_df = spark.read.csv('annotations.csv.gz', sep='|', comment='#')

# Count the new DataFrame and verify the difference is as expected
no_comments_count = no_comments_df.count()
print("Full count: %d\nComment count: %d\nRemaining count: %d" % (full_count, comment_count, no_comments_count))

<script.py> output:
    Full count: 32794
    Comment count: 1416
    Remaining count: 31378
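
The three counts in this output can be checked against each other with a line of arithmetic: the rows left after skipping comments should equal the full count minus the comment count. The numbers below are copied from the output above.

```python
full_count = 32794
comment_count = 1416
no_comments_count = 31378

# Rows that remain should equal all rows minus the commented rows.
expected_remaining = full_count - comment_count
print(expected_remaining == no_comments_count)
```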

Removing invalid rows

# Split _c0 on the tab character and store the list in a variable
tmp_fields = F.split(annotations_df['_c0'], '\t')

# Create the colcount column on the DataFrame
annotations_df = annotations_df.withColumn('colcount', F.size(tmp_fields))

# Remove any rows containing fewer than 5 fields
annotations_df_filtered = annotations_df.filter(~ (annotations_df["colcount"] < 5))

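# Count the number of rows
# (initial_count is not defined in these notes; it is assumed to hold
# annotations_df.count() taken before filtering)
final_count = annotations_df_filtered.count()
print("Initial count: %d\nFinal count: %d" % (initial_count, final_count))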
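
The same validity rule can be sketched in plain Python on a few raw lines: split on the tab character, count the fields, and keep only lines that do not have fewer than 5. The sample lines and the helper field_count are invented for illustration and are not taken from annotations.csv.gz.

```python
# Hypothetical raw lines standing in for the _c0 column.
lines = [
    "folderA\tfile1.jpg\t200\t300\tlabel1",
    "folderB\tfile2.jpg\t500",
    "folderC\tfile3.jpg\t500\t375\tlabel2\tlabel3",
]


def field_count(line):
    """Number of tab-separated fields, like F.size(F.split(col, '\\t'))."""
    return len(line.split("\t"))


# Keep rows that do NOT have fewer than 5 fields, matching the filter above.
kept = [line for line in lines if not field_count(line) < 5]
print(len(lines), len(kept))
```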

Splitting a column into fields

Split the content of _c0 on the tab character (aka, '\t')

split_cols = F.split(annotations_df["_c0"], '\t')

Add the columns folder, filename, width, and height

split_df = annotations_df.withColumn('folder', split_cols.getItem(0))
split_df = split_df.withColumn('filename', split_cols.getItem(1))
split_df = split_df.withColumn('width', split_cols.getItem(2))
split_df = split_df.withColumn('height', split_cols.getItem(3))

Add split_cols as a column

split_df = split_df.withColumn('split_cols', split_cols)

posted @ 2020-11-04 21:36  高文星星