Transform Single Inputs into Tables Using PySpark UDTFs

In PySpark, User-Defined Functions (UDFs) and User-Defined Table Functions (UDTFs) are custom functions that perform complex data transformations.

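A UDF takes input columns and returns a single value per row. Returning multiple rows and columns is therefore cumbersome: the function has to pack its results into an array of structs, and the caller has to unpack them with explode, resulting in more complex and less efficient code.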

from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, StructType, StructField, IntegerType

# Define the schema of the output
schema = ArrayType(
    StructType(
        [
            StructField("num", IntegerType(), False),
            StructField("squared", IntegerType(), False),
        ]
    )
)


# Define the UDF
@udf(returnType=schema)
def square_numbers_udf(start: int, end: int):
    return [(num, num * num) for num in range(start, end + 1)]


# Use in Python
df = spark.createDataFrame([(1, 3)], ["start", "end"])
result_df = df.select(explode(square_numbers_udf(df.start, df.end)).alias("result"))
result_df.select("result.num", "result.squared").show()
+---+-------+
|num|squared|
+---+-------+
|  1|      1|
|  2|      4|
|  3|      9|
+---+-------+

With Python UDTFs, introduced in Apache Spark 3.5, you can create functions that return an entire table from a single call. Each tuple yielded by the eval method becomes one output row, so there is no array to build and no explode step.

from pyspark.sql.functions import udtf, lit


@udtf(returnType="num: int, squared: int")
class SquareNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)


SquareNumbers(lit(1), lit(3)).show()
+---+-------+
|num|squared|
+---+-------+
|  1|      1|
|  2|      4|
|  3|      9|
+---+-------+
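
Since the UDTF handler is an ordinary Python class, its row-generating logic can be checked without a Spark session. The sketch below shows that first, then a way to call the same handler from SQL. It is a minimal illustration under a few assumptions: the SQL function name square_numbers and the helper query_with_sql are names invented here, and the SQL part needs PySpark 3.5 or later and an active SparkSession.

```python
class SquareNumbers:
    """Plain handler class: each tuple yielded by eval becomes one output row."""

    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)


# The row-generating logic can be checked without a Spark session.
rows = list(SquareNumbers().eval(1, 3))
print(rows)  # [(1, 1), (2, 4), (3, 9)]


def query_with_sql(spark):
    """Hypothetical helper: wrap the handler as a UDTF and call it from SQL."""
    from pyspark.sql.functions import udtf  # imported lazily; requires pyspark

    square_numbers = udtf(SquareNumbers, returnType="num: int, squared: int")
    # "square_numbers" is an illustrative name chosen for this sketch.
    spark.udtf.register("square_numbers", square_numbers)
    return spark.sql("SELECT * FROM square_numbers(1, 3)")
```

Calling query_with_sql(spark).show() should print the same three-row table as above. Keeping the class undecorated and wrapping it with udtf(...) at registration time is a design choice made for this sketch, so that the eval logic stays testable as plain Python.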

Work with Khuyen Tran
