Spark两种方法计算分组取TopN

database

Spark 分组取Top N运算

大数据处理中,对数据分组后,取TopN是非常常见的运算。

下面我们以一个例子来展示spark如何进行分组取Top的运算。

1、RDD方法分组取TopN

from pyspark import SparkContext

sc = SparkContext()

准备数据,把数据转换为rdd格式

data_list = [

(0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),

(1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),

(2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),

(3, "cat8", 135.6)

]

data = sc.parallelize(data_list)

data.collect()

[(0, "cat26", 130.9),

(0, "cat13", 122.1),

(0, "cat95", 119.6),

(0, "cat105", 11.3),

(1, "cat67", 128.5),

(1, "cat4", 126.8),

(1, "cat13", 112.6),

(1, "cat23", 15.3),

(2, "cat56", 139.6),

(2, "cat40", 129.7),

(2, "cat187", 127.9),

(2, "cat68", 19.8),

(3, "cat8", 135.6)]

对数据使用groupBy操作来分组。可以看到分组后数据为(key, list_data)

d1 = data.groupBy(lambda x:x[0])

temp = d1.collect()

print(list(temp[0][1]))

print(temp)

[(0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3)]

[(0, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C710>), (1, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C780>), (2, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C898>), (3, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C9B0>)]

使用mapValues方法对数据进行排序。

可以根据需要来取Top N 数据。

这里取Top 3 的数据

d2 = d1.mapValues(lambda x: sorted(x, key=lambda y:y[2])[:3])

d2.collect()

[(0, [(0, "cat105", 11.3), (0, "cat95", 119.6), (0, "cat13", 122.1)]),

(1, [(1, "cat23", 15.3), (1, "cat13", 112.6), (1, "cat4", 126.8)]),

(2, [(2, "cat68", 19.8), (2, "cat187", 127.9), (2, "cat40", 129.7)]),

(3, [(3, "cat8", 135.6)])]

使用flatmap方法把结果拉平,变成一个list返回。

d3 = d2.flatMap(lambda x:[i for i in x[1]])

d3.collect()

[(0, "cat105", 11.3),

(0, "cat95", 119.6),

(0, "cat13", 122.1),

(1, "cat23", 15.3),

(1, "cat13", 112.6),

(1, "cat4", 126.8),

(2, "cat68", 19.8),

(2, "cat187", 127.9),

(2, "cat40", 129.7),

(3, "cat8", 135.6)]

整体代码

from pyspark import SparkContext

# sc = SparkContext()

topN = 3

data_list = [

(0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),

(1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),

(2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),

(3, "cat8", 135.6)

]

data = sc.parallelize(data_list)

d1 = data.groupBy(lambda x:x[0])

d2 = d1.mapValues(lambda x: sorted(x, key=lambda y:y[2])[:topN])

d3 = d2.flatMap(lambda x:[i for i in x[1]])

d3.collect()

[(0, "cat105", 11.3),

(0, "cat95", 119.6),

(0, "cat13", 122.1),

(1, "cat23", 15.3),

(1, "cat13", 112.6),

(1, "cat4", 126.8),

(2, "cat68", 19.8),

(2, "cat187", 127.9),

(2, "cat40", 129.7),

(3, "cat8", 135.6)]

2、Dataframe方法分组取TopN

dataframe数据格式分组取top N,简单的方法是使用Window方法

from pyspark.sql import SparkSession

from pyspark.sql import functions as func

from pyspark.sql import Window

spark = SparkSession.builder.getOrCreate()

data_list = [

(0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),

(1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),

(2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),

(3, "cat8", 135.6)

]

根据数据创建dataframe,并给数据列命名

df = spark.createDataFrame(data_list, ["Hour", "Category", "TotalValue"])

df.show()

+----+--------+----------+

|Hour|Category|TotalValue|

+----+--------+----------+

| 0| cat26| 130.9|

| 0| cat13| 122.1|

| 0| cat95| 119.6|

| 0| cat105| 11.3|

| 1| cat67| 128.5|

| 1| cat4| 126.8|

| 1| cat13| 112.6|

| 1| cat23| 15.3|

| 2| cat56| 139.6|

| 2| cat40| 129.7|

| 2| cat187| 127.9|

| 2| cat68| 19.8|

| 3| cat8| 135.6|

+----+--------+----------+

  1. 使用窗口方法,分片参数为分组的key,

  2. orderBy的参数为排序的key,这里使用desc降序排列。

  3. withColumn(colName, col),为df添加一列,数据为对window函数生成的数据编号

  4. where方法取rn列值小于3的数据,即取top3数据

w = Window.partitionBy(df.Hour).orderBy(df.TotalValue.desc())

top3 = df.withColumn("rn", func.row_number().over(w)).where("rn <=3")

top3.show()

+----+--------+----------+---+

|Hour|Category|TotalValue| rn|

+----+--------+----------+---+

| 0| cat26| 130.9| 1|

| 0| cat13| 122.1| 2|

| 0| cat95| 119.6| 3|

| 1| cat67| 128.5| 1|

| 1| cat4| 126.8| 2|

| 1| cat13| 112.6| 3|

| 3| cat8| 135.6| 1|

| 2| cat56| 139.6| 1|

| 2| cat40| 129.7| 2|

| 2| cat187| 127.9| 3|

+----+--------+----------+---+

### 代码汇总

from pyspark.sql import SparkSession

from pyspark.sql import functions as func

from pyspark.sql import Window

spark = SparkSession.builder.getOrCreate()

data_list = [

(0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),

(1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),

(2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),

(3, "cat8", 135.6)

]

df = spark.createDataFrame(data_list, ["Hour", "Category", "TotalValue"])

w = Window.partitionBy(df.Hour).orderBy(df.TotalValue.desc())

top3 = df.withColumn("rn", func.row_number().over(w)).where("rn <=3")

top3.show()

以上是 Spark两种方法计算分组取TopN 的全部内容, 来源链接: utcz.com/z/534450.html

回到顶部