What is PySpark broadcast join?
PySpark broadcast join is a technique used in PySpark (the Python API for Apache Spark) to improve the performance of join operations when one of the tables being joined is small. The primary goal of a broadcast join is to eliminate the data shuffling and network overhead associated with join operations, which can yield considerable speed benefits. A broadcast join sends the smaller table (or DataFrame) to all worker nodes, ensuring each worker node holds a complete copy of the smaller table in memory. This allows the join to be performed locally on each worker node, eliminating the need to shuffle and transfer data across the network.
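Spark can also apply this strategy automatically: tables smaller than the spark.sql.autoBroadcastJoinThreshold configuration value (10 MB by default) are broadcast without any hint. As a minimal sketch (the app name below is just a placeholder), you could tune or disable this behavior like so:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Threshold Example").getOrCreate()

# Raise the automatic broadcast threshold to roughly 50 MB.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024)

# Setting it to -1 disables automatic broadcast joins entirely,
# leaving broadcasting to explicit broadcast() hints.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
```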
We can use the broadcast() function from the pyspark.sql.functions module to perform broadcast joins in PySpark.
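The general pattern looks like this, where large_df, small_df, and "key" are hypothetical names standing in for any pair of DataFrames and their join column:

```python
from pyspark.sql.functions import broadcast

# Wrap the smaller DataFrame in broadcast() to hint that it
# should be copied to every worker node before the join.
joined = large_df.join(broadcast(small_df), on="key")
```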
Let's create two sample tables in PySpark to demonstrate the broadcast join: a larger sales table and a smaller products table.
Sales Table
| order_id | product_id | quantity |
|----------|------------|----------|
| 1        | 101        | 2        |
| 2        | 102        | 1        |
| 3        | 103        | 3        |
| 4        | 101        | 1        |
| 5        | 104        | 4        |
Products Table
| product_id | product_name | price |
|------------|--------------|-------|
| 101        | Learn C++    | 10    |
| 102        | Mobile: X1   | 20    |
| 103        | LCD          | 30    |
| 104        | Laptop       | 40    |
Now, let's try a broadcast join in PySpark with the tables above:
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

# Initialize the Spark session
spark = SparkSession.builder.appName("Broadcast Join Example").getOrCreate()

# Create DataFrames from sample data
sales_data = [(1, 101, 2), (2, 102, 1), (3, 103, 3), (4, 101, 1), (5, 104, 4)]
products_data = [(101, "Learn C++", 10), (102, "Mobile: X1", 20), (103, "LCD", 30), (104, "Laptop", 40)]

sales_columns = ["order_id", "product_id", "quantity"]
products_columns = ["product_id", "product_name", "price"]

sales_df = spark.createDataFrame(sales_data, schema=sales_columns)
products_df = spark.createDataFrame(products_data, schema=products_columns)

# Perform broadcast join
result = sales_df.join(broadcast(products_df), sales_df["product_id"] == products_df["product_id"])

# Show result
result.show()
```

Explanation
This PySpark code performs a broadcast join between two DataFrames, sales_df and products_df, using the "product_id" column as the key. Here's the explanation of each part:
- Lines 1–2: Import the necessary modules from PySpark: SparkSession and the broadcast function.
- Line 5: Initialize a SparkSession with the name "Broadcast Join Example".
- Lines 8–9: Create the sample sales and products data as lists of tuples.
- Lines 11–12: Define column names for the sales and products DataFrames.
- Lines 14–15: Create the sales and products DataFrames using the sample data and column names.
- Line 18: Perform a broadcast join between the sales and products DataFrames using the "product_id" column as the key. The broadcast function hints that the smaller DataFrame (products_df) should be broadcast to all worker nodes, optimizing the join's performance.
- Line 21: Show the result of the join by calling the show() method on the resulting DataFrame.
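To verify that Spark actually planned a broadcast join rather than a shuffle-based one, you can inspect the physical plan; with the DataFrames above, it typically contains a BroadcastHashJoin node:

```python
# Print the physical plan; a broadcast join appears as
# "BroadcastHashJoin" (with a BroadcastExchange on the small side).
result.explain()
```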
The result of the broadcast join above will be as follows:
Result Table

| order_id | product_id | quantity | product_id | product_name | price |
|----------|------------|----------|------------|--------------|-------|
| 1        | 101        | 2        | 101        | Learn C++    | 10    |
| 2        | 102        | 1        | 102        | Mobile: X1   | 20    |
| 3        | 103        | 3        | 103        | LCD          | 30    |
| 4        | 101        | 1        | 101        | Learn C++    | 10    |
| 5        | 104        | 4        | 104        | Laptop       | 40    |
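Note that the result contains product_id twice, once from each DataFrame, because the join condition compares two separate columns. As a small variation, joining on the column name instead keeps only one copy of the key:

```python
# Joining on the column name (rather than an explicit equality
# condition) deduplicates the join key in the output.
result = sales_df.join(broadcast(products_df), on="product_id")
result.show()
```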