Use PySpark UDFs to Make SQL Logic Reusable

Motivation

Complex SQL queries often involve repetitive calculations and nested subqueries that make code hard to maintain and prone to errors. In large-scale data processing, data engineers frequently end up rewriting the same logic several times, both within a single query and across queries.

Consider a scenario where the same complex CASE expression has to be repeated within and across different queries:

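from pyspark.sql import SparkSession

# Start (or reuse) a Spark session
spark = SparkSession.builder.getOrCreate()

customers_df = spark.createDataFrame([
    (1, "John", 25, 60000),
    (2, "Jane", 17, 0),
    (3, "Bob", 68, 45000)
], ["customer_id", "name", "age", "income"])

# Register the DataFrame as a temporary view so SQL queries can refer to it
customers_df.createOrReplaceTempView("customers")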

# Duplicated CASE logic across queries
query1 = spark.sql("""
    SELECT customer_id,
        CASE 
            WHEN age < 18 THEN 'minor'
            WHEN age > 65 THEN 'senior'
            WHEN income > 50000 THEN 'prime'
            ELSE 'standard'
        END as segment
    FROM customers
""")

query2 = spark.sql("""
    SELECT CASE 
            WHEN age < 18 THEN 'minor'
            WHEN age > 65 THEN 'senior'
            WHEN income > 50000 THEN 'prime'
            ELSE 'standard'
        END as segment,
    COUNT(*) as count
    FROM customers
    GROUP BY CASE 
        WHEN age < 18 THEN 'minor'
        WHEN age > 65 THEN 'senior'
        WHEN income > 50000 THEN 'prime'
        ELSE 'standard'
    END
""")
query1.show()

Output:

+-----------+-------+
|customer_id|segment|
+-----------+-------+
|          1|  prime|
|          2|  minor|
|          3| senior|
+-----------+-------+
query2.show()

Output:

+-------+-----+
|segment|count|
+-------+-----+
|  prime|    1|
|  minor|    1|
| senior|    1|
+-------+-----+

Introduction to PySpark

PySpark is Apache Spark’s Python API. Among other things, it lets you register ordinary Python functions as user-defined functions (UDFs) and call them from Spark SQL queries. Install PySpark:

pip install pyspark[sql]

Reducing Duplication with UDFs

Instead of repeating the CASE expression, define the logic once in a Python function and register it as a UDF:

from pyspark.sql.types import StringType

# Define the segmentation logic once
def segment_customers(age, income):
    # Propagate missing values: the UDF returns NULL if either input is NULL
    if age is None or income is None:
        return None
    if age < 18:
        return "minor"
    elif age > 65:
        return "senior"
    elif income > 50000:
        return "prime"
    return "standard"

# Register UDF with explicit return type
spark.udf.register("segment_customers", segment_customers, StringType())
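One subtlety: this UDF is not a drop-in replacement for the CASE expression when inputs are NULL. In SQL, a comparison such as `age < 18` evaluates to NULL when `age` is NULL, and a WHEN clause treats NULL as not matched, so evaluation falls through to the next branch (and eventually to ELSE). `segment_customers`, by contrast, returns `None` as soon as either value is missing. If you need identical results, a variant along these lines could be registered instead; `segment_customers_like_case` is a hypothetical name, and which behavior is correct depends on how your data should treat missing values.

```python
from typing import Optional


def segment_customers_like_case(age: Optional[int], income: Optional[int]) -> str:
    """Mirror the SQL CASE expression, including its NULL handling.

    In SQL, comparing NULL yields NULL, which a WHEN clause treats as
    not matched, so evaluation moves on to the next branch. Skipping a
    branch when its input is None reproduces that behavior.
    """
    if age is not None and age < 18:
        return "minor"
    if age is not None and age > 65:
        return "senior"
    if income is not None and income > 50000:
        return "prime"
    # The ELSE branch: reached when no earlier condition was true
    return "standard"


# A missing age falls through to the income check, as in the CASE expression
print(segment_customers_like_case(None, 60000))
print(segment_customers_like_case(None, None))
print(segment_customers_like_case(17, None))
```

It would be registered the same way as the original, for example `spark.udf.register("segment_customers", segment_customers_like_case, StringType())`.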

Now you can reuse this logic across multiple queries:

# Query 1: Simple segmentation
query1 = spark.sql("""
    SELECT 
        customer_id,
        segment_customers(age, income) AS segment
    FROM customers
""")

# Query 2: Segment counts
query2 = spark.sql("""
    SELECT 
        segment_customers(age, income) AS segment,
        COUNT(*) as count
    FROM customers
    GROUP BY segment_customers(age, income)
""")

Conclusion

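PySpark UDFs let you define business logic once and call it from any SQL query, which reduces duplication and keeps results consistent across queries. Keeping that logic in a single, well-documented Python function also makes it easier to read and change. The trade-off is performance: Spark cannot optimize inside a Python UDF, and rows must be passed between the JVM and a Python worker, so a UDF is usually slower than the equivalent built-in SQL expression.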
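Because the segmentation rules live in an ordinary Python function, they can also be tested without starting Spark. A minimal sketch, assuming pytest-style test functions (`test_segment_customers` is a hypothetical name); the function body is repeated from above so the snippet runs on its own.

```python
def segment_customers(age, income):
    if age is None or income is None:
        return None
    if age < 18:
        return "minor"
    elif age > 65:
        return "senior"
    elif income > 50000:
        return "prime"
    return "standard"


def test_segment_customers():
    # Each branch of the rules, checked once
    assert segment_customers(17, 100000) == "minor"
    assert segment_customers(70, 100000) == "senior"
    assert segment_customers(30, 60000) == "prime"
    assert segment_customers(30, 50000) == "standard"  # boundary: not > 50000
    # Missing inputs propagate as None (NULL in Spark SQL)
    assert segment_customers(None, 60000) is None


if __name__ == "__main__":
    test_segment_customers()
    print("all segmentation checks passed")
```

Run with `pytest`, or directly as a script; either way the rules are checked without a Spark session.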

Work with Khuyen Tran
