collect_list and collect_set in Pyspark


collect_list and collect_set in Pyspark

from pyspark.sql import functions as F

# Sample data: the (A, B) pair ("a", "b") occurs three times, so grouping on
# A and B collapses those rows into a single group.
df = sc.parallelize(
    [
        ("a", "b", 1),
        ("a", "b", 2),
        ("a", "b", 3),
        ("z", "b", 4),
        ("a", "x", 5),
    ]
).toDF(["A", "B", "C"])

df.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  b|  1|
|  a|  b|  2|
|  a|  b|  3|
|  z|  b|  4|
|  a|  x|  5|
+---+---+---+

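collect_list() is an aggregate function: after grouping the rows that share the same A and B values, it collects all of the C entries of each group into a list (duplicate values are kept, and the order of the elements is not guaranteed).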

# Group on (A, B) and gather every C value of each group into a list.
(
    df.groupBy("A", "B")
    .agg(F.collect_list("C"))
    .show()
)

+---+---+---------------+
|  A|  B|collect_list(C)|
+---+---+---------------+
|  a|  x|            [5]|
|  z|  b|            [4]|
|  a|  b|      [1, 2, 3]|
+---+---+---------------+

collect_set() is used to aggregate the unique values of each group and eliminate duplicates (collect_set() lets us retain all the valuable information and drop the duplicates).

# Sample data: id "123" on date "20180102" has the value 10.49 twice, which
# is the duplicate that collect_set removes below.
df = sc.parallelize(
    [
        ("123", "20180102", 10.49),
        ("123", "20180102", 10.49),
        ("123", "20180102", 77.33),
        ("555", "20180214", 99.99),
        ("888", "20180214", 1.23),
    ]
).toDF(["id", "date", "values"])

df.show()
+---+--------+------+
| id|    date|values|
+---+--------+------+
|123|20180102| 10.49|
|123|20180102| 10.49|
|123|20180102| 77.33|
|555|20180214| 99.99|
|888|20180214|  1.23|
+---+--------+------+

# Group on (id, date) and keep only the distinct values of each group.
(
    df.groupBy("id", "date")
    .agg(F.collect_set("values"))
    .show()
)

+---+--------+-------------------+
| id|    date|collect_set(values)|
+---+--------+-------------------+
|555|20180214|            [99.99]|
|888|20180214|             [1.23]|
|123|20180102|     [10.49, 77.33]|
+---+--------+-------------------+

Data on hockey players and how many goals they scored in each game:


# One row per (player, game); the row for player "123" in game 11 appears
# twice, which is the duplicate the collect_set example below removes.
data = sc.parallelize(
    [
        ("123", 11, "20180102", 0),
        ("123", 11, "20180102", 0),
        ("123", 13, "20180105", 3),
        ("555", 11, "20180214", 1),
        ("888", 22, "20180214", 2),
    ]
).toDF(["player_id", "game_id", "game_date", "goals"])

# Display the DataFrame; the table that follows is this call's output.
data.show()


+---------+-------+---------+-----+
|player_id|game_id|game_date|goals|
+---------+-------+---------+-----+
|      123|     11| 20180102|    0|
|      123|     11| 20180102|    0|
|      123|     13| 20180105|    3|
|      555|     11| 20180214|    1|
|      888|     22| 20180214|    2|
+---------+-------+---------+-----+


# Pack the per-game columns into a single struct column so that collect_set
# de-duplicates whole (game_id, game_date, goals) records per player.
# `F.struct` is used because only `functions as F` is imported at this point;
# a bare `struct` would raise NameError here.
(
    data.withColumn("as_struct", F.struct("game_id", "game_date", "goals"))
    .groupBy("player_id")
    .agg(F.collect_set("as_struct"))
    .show(truncate=False)
)


+---------+--------------------------------------+
|player_id|collect_set(as_struct)                |
+---------+--------------------------------------+
|888      |[[22, 20180214, 2]]                   |
|555      |[[11, 20180214, 1]]                   |
|123      |[[11, 20180102, 0], [13, 20180105, 3]]|
+---------+--------------------------------------+



*********** Another Practical Scenario *******************

import pyspark.sql.functions as F

from pyspark.sql.functions import collect_list

from pyspark.sql.functions import concat, concat_ws, struct


# Sample rows: one patient (123 / XYZ) with two intake ids and two subscriber
# ids per intake. Named `patient_rows` rather than `input` so the built-in
# input() function is not shadowed.
patient_rows = [
    [123, 'XYZ', 1, 'A'],
    [123, 'XYZ', 1, 'B'],
    [123, 'XYZ', 2, 'C'],
    [123, 'XYZ', 2, 'D'],
]


df = spark.createDataFrame(
    patient_rows,
    ["PatientNumber", "PatientSSN", "PatientIntakeId", "PatientSubscriberId"],
)


df.show()

+-------------+----------+---------------+-------------------+
|PatientNumber|PatientSSN|PatientIntakeId|PatientSubscriberId|
+-------------+----------+---------------+-------------------+
|          123|       XYZ|              1|                  A|
|          123|       XYZ|              1|                  B|
|          123|       XYZ|              2|                  C|
|          123|       XYZ|              2|                  D|
+-------------+----------+---------------+-------------------+


# One output row per (PatientNumber, PatientSSN, PatientIntakeId); the
# subscriber ids of each group are gathered into a list that keeps the
# original column name via alias().
df2 = df.groupBy('PatientNumber', 'PatientSSN', 'PatientIntakeId').agg(
    collect_list('PatientSubscriberId').alias('PatientSubscriberId')
)

# Must start at column 0: a leading space on a top-level statement is an
# IndentationError.
df2.show()
+-------------+----------+---------------+-------------------+
|PatientNumber|PatientSSN|PatientIntakeId|PatientSubscriberId|
+-------------+----------+---------------+-------------------+
|          123|       XYZ|              1|             [A, B]|
|          123|       XYZ|              2|             [C, D]|
+-------------+----------+---------------+-------------------+

# Group at patient level and keep every (PatientIntakeId, PatientSubscriberId)
# pair as a struct inside a single list column named col_list.
df3 = (
    df.groupBy('PatientNumber', 'PatientSSN')
    .agg(
        collect_list(struct('PatientIntakeId', 'PatientSubscriberId')).alias('col_list')
    )
)


df3.show(truncate=False)


+-------------+----------+--------------------------------+
|PatientNumber|PatientSSN|col_list                        |
+-------------+----------+--------------------------------+
|123          |XYZ       |[[1, A], [1, B], [2, C], [2, D]]|
+-------------+----------+--------------------------------+











No comments:

Post a Comment

Pages