Motivation
Complex SQL queries often contain repeated calculations and nested subqueries that make them hard to maintain and easy to get wrong. In large-scale data processing, data engineers frequently end up rewriting the same business logic in several places across their queries.
Consider a scenario where the same complex CASE expression has to be repeated across different queries:
from pyspark.sql import SparkSession

# Create (or reuse) a local SparkSession
spark = SparkSession.builder.getOrCreate()

customers_df = spark.createDataFrame([
    (1, "John", 25, 60000),
    (2, "Jane", 17, 0),
    (3, "Bob", 68, 45000),
], ["customer_id", "name", "age", "income"])

# Register the DataFrame as a temporary view
customers_df.createOrReplaceTempView("customers")
# Duplicated CASE logic across queries
query1 = spark.sql("""
    SELECT customer_id,
           CASE
               WHEN age < 18 THEN 'minor'
               WHEN age > 65 THEN 'senior'
               WHEN income > 50000 THEN 'prime'
               ELSE 'standard'
           END AS segment
    FROM customers
""")
query2 = spark.sql("""
    SELECT CASE
               WHEN age < 18 THEN 'minor'
               WHEN age > 65 THEN 'senior'
               WHEN income > 50000 THEN 'prime'
               ELSE 'standard'
           END AS segment,
           COUNT(*) AS count
    FROM customers
    GROUP BY CASE
                 WHEN age < 18 THEN 'minor'
                 WHEN age > 65 THEN 'senior'
                 WHEN income > 50000 THEN 'prime'
                 ELSE 'standard'
             END
""")
query1.show()
Output:
+-----------+-------+
|customer_id|segment|
+-----------+-------+
|          1|  prime|
|          2|  minor|
|          3| senior|
+-----------+-------+
query2.show()
Output:
+-------+-----+
|segment|count|
+-------+-----+
|  prime|    1|
|  minor|    1|
| senior|    1|
+-------+-----+
Introduction to PySpark
PySpark is the Python API for Apache Spark. Among other things, it lets you register plain Python functions as user-defined functions (UDFs) and call them from SQL queries. Install PySpark with pip:
pip install pyspark[sql]
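The [sql] extra pulls in optional dependencies such as pandas and PyArrow that Spark SQL can take advantage of. A quick sanity check that the package is importable and which version you got:
import pyspark

# Confirm the installed PySpark version
print(pyspark.__version__)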
Reducing Duplication with UDFs
Instead of repeating the CASE expression, define the segmentation logic once as a UDF:
from pyspark.sql.types import StringType

# Define the segmentation logic once
def segment_customers(age, income):
    if age is None or income is None:
        return None
    if age < 18:
        return "minor"
    elif age > 65:
        return "senior"
    elif income > 50000:
        return "prime"
    return "standard"

# Register the UDF for SQL use, with an explicit return type
spark.udf.register("segment_customers", segment_customers, StringType())
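The same Python function can also back a UDF for the DataFrame API, so SQL and DataFrame code share a single implementation. A minimal sketch (the segment column name here is just for illustration):
from pyspark.sql.functions import udf

# Wrap the same function for use in DataFrame expressions
segment_udf = udf(segment_customers, StringType())

customers_df.withColumn("segment", segment_udf("age", "income")).show()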
Now you can reuse this logic across multiple queries:
# Query 1: Simple segmentation
query1 = spark.sql("""
    SELECT
        customer_id,
        segment_customers(age, income) AS segment
    FROM customers
""")

# Query 2: Segment counts
query2 = spark.sql("""
    SELECT
        segment_customers(age, income) AS segment,
        COUNT(*) AS count
    FROM customers
    GROUP BY segment_customers(age, income)
""")
Conclusion
PySpark UDFs provide a powerful way to reduce code duplication and maintain consistency in complex SQL queries. By centralizing business logic in well-documented, reusable functions, you can write clearer, more maintainable code while ensuring consistent implementation across your entire application.