In PySpark, User-Defined Functions (UDFs) and User-Defined Table Functions (UDTFs) are custom functions that perform complex data transformations.
UDFs take input columns and return a single value. However, they become cumbersome when a transformation needs to produce multiple rows and columns, resulting in complex and inefficient code.
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, StructType, StructField, IntegerType

# Create (or reuse) the active SparkSession
spark = SparkSession.builder.getOrCreate()

# Define the schema of the output: an array of (num, squared) structs
schema = ArrayType(
    StructType(
        [
            StructField("num", IntegerType(), False),
            StructField("squared", IntegerType(), False),
        ]
    )
)

# Define the UDF
@udf(returnType=schema)
def square_numbers_udf(start: int, end: int):
    return [(num, num * num) for num in range(start, end + 1)]

# Use in Python: explode the returned array into one row per struct
df = spark.createDataFrame([(1, 3)], ["start", "end"])
result_df = df.select(explode(square_numbers_udf(df.start, df.end)).alias("result"))
result_df.select("result.num", "result.squared").show()
+---+-------+
|num|squared|
+---+-------+
| 1| 1|
| 2| 4|
| 3| 9|
+---+-------+
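For completeness, the same UDF can also be registered for use from SQL. The following is a minimal sketch, assuming the SparkSession from the snippet above; the registered name square_numbers is arbitrary. Note that the explode-and-unpack step is still required:

# Register the UDF under an illustrative name and call it from SQL
spark.udf.register("square_numbers", square_numbers_udf)
spark.sql("""
    SELECT result.num, result.squared
    FROM (SELECT explode(square_numbers(1, 3)) AS result) AS t
""").show()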
With UDTFs, introduced in Apache Spark 3.5, you can write functions that return an entire table (multiple rows and columns) from a single invocation, making this kind of transformation much simpler to express.
from pyspark.sql.functions import lit, udtf

@udtf(returnType="num: int, squared: int")
class SquareNumbers:
    def eval(self, start: int, end: int):
        # Emit one row per number in the range: (num, num * num)
        for num in range(start, end + 1):
            yield (num, num * num)

SquareNumbers(lit(1), lit(3)).show()
+---+-------+
|num|squared|
+---+-------+
| 1| 1|
| 2| 4|
| 3| 9|
+---+-------+
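The UDTF can also be called from SQL once it is registered on the session. This is a minimal sketch, assuming the SparkSession from the earlier snippet; the registered name square_numbers_udtf is illustrative:

# Register the UDTF and invoke it in the FROM clause like a table
spark.udtf.register("square_numbers_udtf", SquareNumbers)
spark.sql("SELECT * FROM square_numbers_udtf(1, 3)").show()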