The Complete PySpark SQL Guide: DataFrames, Aggregations, Window Functions, and Pandas UDFs

Introduction

pandas works great for prototyping but fails when datasets grow beyond memory capacity. While PySpark offers distributed computing to handle massive datasets, mastering its unfamiliar syntax and rewriting existing code create a steep barrier to adoption.

PySpark SQL bridges this gap by offering SQL-style DataFrame operations and query strings, eliminating the need to learn PySpark’s lower-level RDD programming model and functional transformations.

In this comprehensive guide, you’ll learn PySpark SQL from the ground up:

  • Load, explore, and manipulate DataFrames with selection and filtering operations
  • Aggregate data and work with strings, dates, and time series
  • Use window functions for rankings, running totals, and moving averages
  • Execute SQL queries alongside DataFrame operations
  • Create custom functions with pandas UDFs for vectorized performance

💻 Get the Code: The complete source code and Jupyter notebook for this tutorial are available on GitHub. Clone it to follow along!

Getting Started

First, install PySpark:

pip install pyspark

Create a SparkSession to start working with PySpark:

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("PySpark SQL Guide")
    .getOrCreate()
)

The SparkSession is your entry point to all PySpark functionality.

Creating DataFrames

PySpark supports creating DataFrames from multiple sources including Python objects, pandas DataFrames, files, and databases.

Create from Python dictionaries:

data = {
    "customer_id": [1, 2, 3, 4, 5],
    "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"],
    "region": ["North", "South", "North", "East", "West"],
    "amount": [100, 150, 200, 120, 180]
}

# Create DataFrame by zipping the dictionary's columns into row tuples
df = spark.createDataFrame(
    list(zip(
        data["customer_id"],
        data["name"],
        data["region"],
        data["amount"]
    )),
    ["customer_id", "name", "region", "amount"]
)

Convert from pandas:

import pandas as pd

pandas_df = pd.DataFrame(data)
# Convert pandas DataFrame to PySpark DataFrame
df = spark.createDataFrame(pandas_df)
df.show()

Load from CSV files:

# Load CSV file with automatic schema inference
df = spark.read.csv("data.csv", header=True, inferSchema=True)

Read with schema specification:

from pyspark.sql.types import StructType, StructField, StringType, LongType

schema = StructType([
    StructField("customer_id", LongType(), True),
    StructField("name", StringType(), True),
    StructField("region", StringType(), True),
    StructField("amount", LongType(), True)
])

# Load CSV file with explicit schema definition
df = spark.read.csv("data.csv", header=True, schema=schema)
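
PySpark also reads and writes columnar formats such as Parquet (the quick reference at the end mentions read.parquet()). A minimal sketch, assuming a local path named customers.parquet:

# Write the DataFrame to Parquet, then read it back
# (the path "customers.parquet" is illustrative)
df.write.mode("overwrite").parquet("customers.parquet")
df_parquet = spark.read.parquet("customers.parquet")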

Understanding Lazy Evaluation

PySpark’s execution model differs fundamentally from pandas. Operations are divided into two types.

Transformations are lazy operations that build execution plans without running:

# Transformations return immediately
filtered = df.filter(df.amount > 100)
print(f"Filtered: {filtered}")

selected = filtered.select("name", "amount")
print(f"Selected: {selected}")

Output:

Filtered: DataFrame[customer_id: bigint, name: string, region: string, amount: bigint]
Selected: DataFrame[name: string, amount: bigint]

Common transformations include select(), filter(), withColumn(), and groupBy(). They return instantly because they only build an execution plan and can be chained without performance cost.

Actions trigger execution and return actual results:

selected.show()

Output:

+-------+------+
|   name|amount|
+-------+------+
|    Bob|   150|
|Charlie|   200|
|  Diana|   120|
|    Eve|   180|
+-------+------+

Common actions include show(), collect(), count(), and describe(). They execute the entire chain of transformations and return actual results.
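
For instance, collect() brings a small result back to the driver as a list of Row objects; a minimal sketch using the selected DataFrame from above:

# collect() returns Row objects on the driver; use it only for small results
rows = selected.collect()
print(rows[0]["name"], rows[0]["amount"])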

This lazy evaluation enables Spark’s Catalyst optimizer to analyze your complete workflow and apply optimizations like predicate pushdown and column pruning before execution.
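
You can watch the optimizer at work with explain(), which prints the query plan without triggering execution (the exact output varies by Spark version):

# Inspect the plan Spark builds before any action runs
df.filter(df.amount > 100).select("name", "amount").explain()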

Data Exploration

Data exploration in PySpark works similarly to pandas, but with methods designed for distributed computing. Instead of pandas’ df.info() and df.head(), PySpark uses printSchema() and show() to inspect schemas and preview records across the cluster.

View the schema:

df.printSchema()

Output:

root
 |-- customer_id: long (nullable = true)
 |-- name: string (nullable = true)
 |-- region: string (nullable = true)
 |-- amount: long (nullable = true)

Preview the first few rows:

# Display the first 5 rows of the DataFrame
df.show(5)

Output:

+-----------+-------+------+------+
|customer_id|   name|region|amount|
+-----------+-------+------+------+
|          1|  Alice| North|   100|
|          2|    Bob| South|   150|
|          3|Charlie| North|   200|
|          4|  Diana|  East|   120|
|          5|    Eve|  West|   180|
+-----------+-------+------+------+

Get summary statistics:

# Generate summary statistics for numeric columns
df.describe().show()

Output:

+-------+-----------+-------+------+------------------+
|summary|customer_id|   name|region|            amount|
+-------+-----------+-------+------+------------------+
|  count|          5|      5|     5|                 5|
|   mean|        3.0|   null|  null|             150.0|
| stddev|       1.58|   null|  null|40.311288741492746|
|    min|          1|  Alice|  East|               100|
|    max|          5|    Eve|  West|               200|
+-------+-----------+-------+------+------------------+

Count total rows:

df.count()

Output:

5

List column names:

df.columns

Output:

['customer_id', 'name', 'region', 'amount']

Get distinct values in a column:

# Get all unique values in the region column
df.select("region").distinct().show()

Output:

+------+
|region|
+------+
| South|
|  East|
|  West|
| North|
+------+

Sample random rows:

# Randomly sample 60% of the rows
df.sample(fraction=0.6, seed=42).show()

Output:

+-----------+-----+------+------+
|customer_id| name|region|amount|
+-----------+-----+------+------+
|          2|  Bob| South|   150|
|          4|Diana|  East|   120|
+-----------+-----+------+------+

Selection & Filtering

When selecting and filtering data, PySpark uses explicit methods like select() and filter() that build distributed execution plans.

Select specific columns:

# Select columns name and amount
df.select("name", "amount").show()

Output:

+-------+------+
|   name|amount|
+-------+------+
|  Alice|   100|
|    Bob|   150|
|Charlie|   200|
|  Diana|   120|
|    Eve|   180|
+-------+------+

Filter rows with conditions:

# Filter rows where amount is greater than 150
df.filter(df.amount > 150).show()

Output:

+-----------+-------+------+------+
|customer_id|   name|region|amount|
+-----------+-------+------+------+
|          3|Charlie| North|   200|
|          5|    Eve|  West|   180|
+-----------+-------+------+------+

Chain multiple filters:

# Get rows where amount is greater than 100 and region is North
(
    df.filter(df.amount > 100)
    .filter(df.region == "North")
    .show()
)

Output:

+-----------+-------+------+------+
|customer_id|   name|region|amount|
+-----------+-------+------+------+
|          3|Charlie| North|   200|
+-----------+-------+------+------+

Drop columns:

# Drop the customer_id column
df.drop("customer_id").show()

Output:

+-------+------+------+
|   name|region|amount|
+-------+------+------+
|  Alice| North|   100|
|    Bob| South|   150|
|Charlie| North|   200|
|  Diana|  East|   120|
|    Eve|  West|   180|
+-------+------+------+

Column Operations

Unlike pandas’ mutable operations where df['new_col'] modifies the DataFrame in place, PySpark’s withColumn() and withColumnRenamed() return new DataFrames, maintaining the distributed computing model.

The withColumn() method takes two arguments: the new column name and an expression defining its values:

from pyspark.sql.functions import col

# Add a new column with the amount with tax
df.withColumn("amount_with_tax", col("amount") * 1.1).show()

Output:

+-----------+-------+------+------+------------------+
|customer_id|   name|region|amount|   amount_with_tax|
+-----------+-------+------+------+------------------+
|          1|  Alice| North|   100|110.00000000000001|
|          2|    Bob| South|   150|             165.0|
|          3|Charlie| North|   200|220.00000000000003|
|          4|  Diana|  East|   120|             132.0|
|          5|    Eve|  West|   180|198.00000000000003|
+-----------+-------+------+------+------------------+

Add constant value columns with lit():

from pyspark.sql.functions import lit

# Add a column with the same value for all rows
df.withColumn("country", lit("USA")).select("name", "amount", "country").show()

Output:

+-------+------+-------+
|   name|amount|country|
+-------+------+-------+
|  Alice|   100|    USA|
|    Bob|   150|    USA|
|Charlie|   200|    USA|
|  Diana|   120|    USA|
|    Eve|   180|    USA|
+-------+------+-------+

Use withColumnRenamed() to rename a column by specifying the old name and new name:

# Rename the amount column to revenue
df.withColumnRenamed("amount", "revenue").show()

Output:

+-----------+-------+------+-------+
|customer_id|   name|region|revenue|
+-----------+-------+------+-------+
|          1|  Alice| North|    100|
|          2|    Bob| South|    150|
|          3|Charlie| North|    200|
|          4|  Diana|  East|    120|
|          5|    Eve|  West|    180|
+-----------+-------+------+-------+

Use the cast() method to convert a column to a different data type:

# Cast the amount column to a string
df.withColumn("amount_str", col("amount").cast("string")).printSchema()

Output:

root
 |-- customer_id: long (nullable = true)
 |-- name: string (nullable = true)
 |-- region: string (nullable = true)
 |-- amount: long (nullable = true)
 |-- amount_str: string (nullable = true)

Aggregation Functions

Unlike pandas’ in-memory aggregations, PySpark’s groupBy() and aggregation functions distribute calculations across cluster nodes, using the same conceptual model as pandas but with lazy evaluation.

Apply aggregation functions directly in select() to compute values across all rows without grouping:

from pyspark.sql.functions import sum, avg, count, max, min

# Calculate total revenue, average revenue, and count across all rows
df.select(
    sum("amount").alias("total_revenue"),
    avg("amount").alias("avg_revenue"),
    count("*").alias("order_count")
).show()

Output:

+-------------+-----------+-----------+
|total_revenue|avg_revenue|order_count|
+-------------+-----------+-----------+
|          750|      150.0|          5|
+-------------+-----------+-----------+

Combine groupBy() to create groups and agg() to compute multiple aggregations per group:

# Calculate total revenue and customer count by region
(
    df.groupBy("region")
    .agg(
        sum("amount").alias("total_revenue"),
        count("*").alias("customer_count")
    )
    .show()
)

Output:

+------+-------------+--------------+
|region|total_revenue|customer_count|
+------+-------------+--------------+
| North|          300|             2|
|  East|          120|             1|
| South|          150|             1|
|  West|          180|             1|
+------+-------------+--------------+

Combine groupBy() with collect_list() to create arrays of values for each group:

from pyspark.sql.functions import collect_list

# Collect customer names into an array for each region
(
    df.groupBy("region")
    .agg(collect_list("name").alias("customers"))
    .show(truncate=False)
)

Output:

+------+----------------+
|region|customers       |
+------+----------------+
|South |[Bob]           |
|East  |[Diana]         |
|West  |[Eve]           |
|North |[Alice, Charlie]|
+------+----------------+

String Functions

Unlike pandas’ vectorized string methods accessed via .str, PySpark provides importable functions like concat(), split(), and regexp_replace() that transform entire columns across distributed partitions.

Use concat() to combine multiple columns and literal strings, wrapping constant values with lit():

from pyspark.sql.functions import concat, lit

# Concatenate customer name and region with a separator
df.withColumn("full_info", concat(col("name"), lit(" - "), col("region"))).show()

Output:

+-----------+-------+------+------+---------------+
|customer_id|   name|region|amount|      full_info|
+-----------+-------+------+------+---------------+
|          1|  Alice| North|   100|  Alice - North|
|          2|    Bob| South|   150|    Bob - South|
|          3|Charlie| North|   200|Charlie - North|
|          4|  Diana|  East|   120|   Diana - East|
|          5|    Eve|  West|   180|     Eve - West|
+-----------+-------+------+------+---------------+

Use split() to divide a string column into an array based on a delimiter pattern:

from pyspark.sql.functions import split

# Create sample data with email addresses
email_data = spark.createDataFrame(
    [("alice@company.com",), ("bob@startup.io",), ("charlie@corp.net",)],
    ["email"]
)

# Split email into username and domain
(
    email_data.withColumn("email_parts", split(col("email"), "@"))
    .select("email", "email_parts")
    .show(truncate=False)
)

Output:

+-----------------+----------------------+
|email            |email_parts           |
+-----------------+----------------------+
|alice@company.com|[alice, company.com]  |
|bob@startup.io   |[bob, startup.io]     |
|charlie@corp.net |[charlie, corp.net]   |
+-----------------+----------------------+

Use regexp_replace() to find and replace text patterns using regular expressions:

from pyspark.sql.functions import regexp_replace

# Create sample data with phone numbers
phone_data = spark.createDataFrame(
    [("Alice", "123-456-7890"), ("Bob", "987-654-3210"), ("Charlie", "555-123-4567")],
    ["name", "phone"]
)

# Mask phone numbers, keeping only last 4 digits
(
    phone_data.withColumn("masked_phone", regexp_replace(col("phone"), r"\d{3}-\d{3}-(\d{4})", "XXX-XXX-$1"))
    .select("name", "phone", "masked_phone")
    .show()
)

Output:

+-------+------------+------------+
|   name|       phone|masked_phone|
+-------+------------+------------+
|  Alice|123-456-7890|XXX-XXX-7890|
|    Bob|987-654-3210|XXX-XXX-3210|
|Charlie|555-123-4567|XXX-XXX-4567|
+-------+------------+------------+

Date/Time Functions

Working with dates and timestamps is essential for time-based analysis. PySpark offers comprehensive functions to extract date components, format timestamps, and perform temporal operations.

Create sample data with dates:

from datetime import datetime, timedelta

date_data = [
    (1, datetime(2024, 1, 15), 100),
    (2, datetime(2024, 2, 20), 150),
    (3, datetime(2024, 3, 10), 200),
    (4, datetime(2024, 4, 5), 120),
    (5, datetime(2024, 5, 25), 180)
]

# Create sample DataFrame with datetime values
df_dates = spark.createDataFrame(date_data, ["id", "order_date", "amount"])

Use functions like year(), month(), and dayofmonth() to extract individual date components from timestamp columns:

from pyspark.sql.functions import year, month, dayofmonth

# Extract year, month, and day components from order_date
(
    df_dates.withColumn("year", year("order_date"))
    .withColumn("month", month("order_date"))
    .withColumn("day", dayofmonth("order_date"))
    .show()
)

Output:

+---+-------------------+------+----+-----+---+
| id|         order_date|amount|year|month|day|
+---+-------------------+------+----+-----+---+
|  1|2024-01-15 00:00:00|   100|2024|    1| 15|
|  2|2024-02-20 00:00:00|   150|2024|    2| 20|
|  3|2024-03-10 00:00:00|   200|2024|    3| 10|
|  4|2024-04-05 00:00:00|   120|2024|    4|  5|
|  5|2024-05-25 00:00:00|   180|2024|    5| 25|
+---+-------------------+------+----+-----+---+

Use date_format() to convert dates to custom string formats:

from pyspark.sql.functions import date_format

# Format timestamps as YYYY-MM-DD strings
(
    df_dates.withColumn("formatted_date", date_format("order_date", "yyyy-MM-dd"))
    .select("order_date", "formatted_date")
    .show()
)

Output:

+-------------------+--------------+
|         order_date|formatted_date|
+-------------------+--------------+
|2024-01-15 00:00:00|    2024-01-15|
|2024-02-20 00:00:00|    2024-02-20|
|2024-03-10 00:00:00|    2024-03-10|
|2024-04-05 00:00:00|    2024-04-05|
|2024-05-25 00:00:00|    2024-05-25|
+-------------------+--------------+

Use to_timestamp() to convert string columns to timestamp objects by specifying the date format pattern:

from pyspark.sql.functions import to_timestamp

string_dates = spark.createDataFrame(
    [("2024-01-15",), ("2024-02-20",)],
    ["date_string"]
)

# Convert date strings to timestamp objects
string_dates.withColumn(
    "timestamp",
    to_timestamp("date_string", "yyyy-MM-dd")
).show()

Output:

+-----------+-------------------+
|date_string|          timestamp|
+-----------+-------------------+
| 2024-01-15|2024-01-15 00:00:00|
| 2024-02-20|2024-02-20 00:00:00|
+-----------+-------------------+
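
The introduction to this section also mentions temporal operations; here is a minimal sketch of date arithmetic with date_add() and datediff() (output omitted because it depends on the current date):

from pyspark.sql.functions import date_add, datediff, current_date

# Add 30 days to each order date and count the days elapsed since the order
(
    df_dates.withColumn("due_date", date_add("order_date", 30))
    .withColumn("days_since_order", datediff(current_date(), "order_date"))
    .select("order_date", "due_date", "days_since_order")
    .show()
)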

Working with Time Series

Time series analysis often requires comparing values across different time periods. PySpark’s window functions with lag and lead operations enable calculations of changes and trends over time.

Create sample time series data:

ts_data = [
    (1, datetime(2024, 1, 1), 100),
    (1, datetime(2024, 1, 2), 120),
    (1, datetime(2024, 1, 3), 110),
    (2, datetime(2024, 1, 1), 200),
    (2, datetime(2024, 1, 2), 220),
    (2, datetime(2024, 1, 3), 210)
]

# Create time series data with multiple dates per customer
df_ts = spark.createDataFrame(ts_data, ["customer_id", "date", "amount"])

Calculate the previous row’s value within each customer group using lag():

from pyspark.sql.window import Window
from pyspark.sql.functions import lag

# Create a window: group by customer_id, order by date
window_spec = Window.partitionBy("customer_id").orderBy("date")

# Get the previous amount for each customer
df_ts.withColumn("prev_amount", lag("amount").over(window_spec)).show()

Output:

+-----------+-------------------+------+-----------+
|customer_id|               date|amount|prev_amount|
+-----------+-------------------+------+-----------+
|          1|2024-01-01 00:00:00|   100|       null|
|          1|2024-01-02 00:00:00|   120|        100|
|          1|2024-01-03 00:00:00|   110|        120|
|          2|2024-01-01 00:00:00|   200|       null|
|          2|2024-01-02 00:00:00|   220|        200|
|          2|2024-01-03 00:00:00|   210|        220|
+-----------+-------------------+------+-----------+

The first row in each customer group has null for prev_amount because there’s no previous value.

Calculate day-over-day change by combining lag() to get the previous value and subtracting it from the current value:

# Calculate the day-over-day change in amount
(
    df_ts.withColumn("prev_amount", lag("amount", 1).over(window_spec))
    .withColumn("daily_change", col("amount") - col("prev_amount"))
    .show()
)

Output:

+-----------+-------------------+------+-----------+------------+
|customer_id|               date|amount|prev_amount|daily_change|
+-----------+-------------------+------+-----------+------------+
|          1|2024-01-01 00:00:00|   100|       null|        null|
|          1|2024-01-02 00:00:00|   120|        100|          20|
|          1|2024-01-03 00:00:00|   110|        120|         -10|
|          2|2024-01-01 00:00:00|   200|       null|        null|
|          2|2024-01-02 00:00:00|   220|        200|          20|
|          2|2024-01-03 00:00:00|   210|        220|         -10|
+-----------+-------------------+------+-----------+------------+
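
lead() works the same way in the opposite direction, pulling the next row's value within each group; a minimal sketch reusing the same window_spec:

from pyspark.sql.functions import lead

# Get the next amount for each customer (null on the last row of each group)
df_ts.withColumn("next_amount", lead("amount").over(window_spec)).show()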

Window Analytics

Complex analytics operations like rankings, running totals, and moving averages require window functions that operate within data partitions. These functions enable sophisticated analytical queries without self-joins.

Apply ranking functions within partitioned groups:

  • Combine Window.partitionBy() and Window.orderBy() to rank within groups
  • rank() handles ties by giving them the same rank with gaps (e.g., 1, 2, 2, 4)
  • row_number() always assigns unique sequential numbers (e.g., 1, 2, 3, 4)
  • dense_rank() gives ties the same rank without gaps (e.g., 1, 2, 2, 3)

from pyspark.sql.functions import rank, row_number, dense_rank

# Create sample data with categories to show ranking within groups
ranking_data = spark.createDataFrame(
    [("Math", "Alice", 100), ("Math", "Bob", 150), ("Math", "Charlie", 150),
     ("Science", "Diana", 200), ("Science", "Eve", 100)],
    ["subject", "name", "score"]
)

# Define window partitioned by subject, ordered by score descending
window_spec = Window.partitionBy("subject").orderBy(col("score").desc())

# Calculate different ranking methods within each subject
(
    ranking_data.withColumn("rank", rank().over(window_spec))
    .withColumn("row_number", row_number().over(window_spec))
    .withColumn("dense_rank", dense_rank().over(window_spec))
    .show()
)

Output:

+-------+-------+-----+----+----------+----------+
|subject|   name|score|rank|row_number|dense_rank|
+-------+-------+-----+----+----------+----------+
|   Math|    Bob|  150|   1|         1|         1|
|   Math|Charlie|  150|   1|         2|         1|
|   Math|  Alice|  100|   3|         3|         2|
|Science|  Diana|  200|   1|         1|         1|
|Science|    Eve|  100|   2|         2|         2|
+-------+-------+-----+----+----------+----------+

Calculate running totals using rowsBetween() to define a window range:

  • Window.unboundedPreceding starts the window at the first row of the partition
  • Window.currentRow ends the window at the current row being processed
  • This creates an expanding window that includes all rows from the start up to the current position

from pyspark.sql.functions import sum as _sum

# Create daily sales data with store identifier
daily_sales = spark.createDataFrame(
    [("A", 1, 50), ("A", 2, 75), ("A", 3, 100),
     ("B", 1, 25), ("B", 2, 150), ("B", 3, 80)],
    ["store", "day", "sales"]
)

# Define window partitioned by store, from beginning to current row
window_spec = (
    Window.partitionBy("store")
    .orderBy("day")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

# Calculate running total of sales per store
daily_sales.withColumn("running_total", _sum("sales").over(window_spec)).show()

Output:

+-----+---+-----+-------------+
|store|day|sales|running_total|
+-----+---+-----+-------------+
|    A|  1|   50|           50|
|    A|  2|   75|          125|
|    A|  3|  100|          225|
|    B|  1|   25|           25|
|    B|  2|  150|          175|
|    B|  3|   80|          255|
+-----+---+-----+-------------+

Use rowsBetween(-1, 1) to create a 3-row sliding window that includes the previous row, current row, and next row:

# Define window for 3-row moving average (previous, current, next)
window_spec = (
    Window.partitionBy("customer_id")
    .orderBy("date")
    .rowsBetween(-1, 1)
)

# Calculate moving average of amount over the window
df_ts.withColumn("moving_avg", avg("amount").over(window_spec)).show()

Output:

+-----------+-------------------+------+----------+
|customer_id|               date|amount|moving_avg|
+-----------+-------------------+------+----------+
|          1|2024-01-01 00:00:00|   100|     110.0|
|          1|2024-01-02 00:00:00|   120|     110.0|
|          1|2024-01-03 00:00:00|   110|     115.0|
|          2|2024-01-01 00:00:00|   200|     210.0|
|          2|2024-01-02 00:00:00|   220|     210.0|
|          2|2024-01-03 00:00:00|   210|     215.0|
+-----------+-------------------+------+----------+

Join Operations

Combining data from multiple tables is a core operation in data analysis. PySpark supports various join types including inner, left, and broadcast joins, with automatic optimization for performance.

Create sample tables for joining:

# Create sample customers and orders tables for joining
customers = spark.createDataFrame(
    [(1, "Alice", "US"), (2, "Bob", "UK"), (3, "Charlie", "US")],
    ["customer_id", "name", "country"]
)

orders = spark.createDataFrame(
    [(101, 1, 100), (102, 2, 150), (103, 1, 200), (104, 3, 120)],
    ["order_id", "customer_id", "amount"]
)

Use join() to perform an inner join, which returns only rows with matching keys in both DataFrames:

# Join customers and orders on customer_id
customers.join(orders, "customer_id").show()

Output:

+-----------+-------+-------+--------+------+
|customer_id|   name|country|order_id|amount|
+-----------+-------+-------+--------+------+
|          1|  Alice|     US|     101|   100|
|          1|  Alice|     US|     103|   200|
|          2|    Bob|     UK|     102|   150|
|          3|Charlie|     US|     104|   120|
+-----------+-------+-------+--------+------+

Perform a left join by specifying "left" as the third argument, which retains all left table rows regardless of matches:

# Create extended customers table including Diana
customers_extended = spark.createDataFrame(
    [(1, "Alice", "US"), (2, "Bob", "UK"), (3, "Charlie", "US"), (4, "Diana", "CA")],
    ["customer_id", "name", "country"]
)

# Left join to keep all customers even without orders
customers_extended.join(orders, "customer_id", "left").show()

Output:

+-----------+-------+-------+--------+------+
|customer_id|   name|country|order_id|amount|
+-----------+-------+-------+--------+------+
|          1|  Alice|     US|     103|   200|
|          1|  Alice|     US|     101|   100|
|          2|    Bob|     UK|     102|   150|
|          3|Charlie|     US|     104|   120|
|          4|  Diana|     CA|    NULL|  NULL|
+-----------+-------+-------+--------+------+

Chain multiple join() calls together to combine three or more DataFrames in sequence:

# Create products and order_items tables
products = spark.createDataFrame(
    [(1, "Widget"), (2, "Gadget")],
    ["product_id", "product_name"]
)

order_items = spark.createDataFrame(
    [(101, 1), (102, 2), (103, 1), (104, 2)],
    ["order_id", "product_id"]
)

# Chain multiple joins to combine orders, items, and products
(
    orders.join(order_items, "order_id")
    .join(products, "product_id")
    .select("order_id", "customer_id", "product_name", "amount")
    .show()
)

Output:

+--------+-----------+------------+------+
|order_id|customer_id|product_name|amount|
+--------+-----------+------------+------+
|     103|          1|      Widget|   200|
|     101|          1|      Widget|   100|
|     104|          3|      Gadget|   120|
|     102|          2|      Gadget|   150|
+--------+-----------+------------+------+
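
The introduction to this section mentions broadcast joins: when one table is small, you can hint Spark to ship it to every executor and skip shuffling the larger table. A minimal sketch using the tables above:

from pyspark.sql.functions import broadcast

# Broadcast the small customers table so the join avoids shuffling orders
orders.join(broadcast(customers), "customer_id").show()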

SQL Integration

PySpark supports standard SQL syntax for querying data. You can write SQL queries using familiar SELECT, JOIN, and WHERE clauses alongside PySpark operations.

Use createOrReplaceTempView() to register a DataFrame as a temporary SQL table, allowing it to be queried multiple times with SQL syntax:

# Register DataFrame as a temporary SQL view named customers
df.createOrReplaceTempView("customers")

Execute SQL queries on DataFrames registered with createOrReplaceTempView() using spark.sql():

# Execute SQL query to aggregate revenue by region
spark.sql("""
    SELECT region, SUM(amount) as total_revenue
    FROM customers
    GROUP BY region
    ORDER BY total_revenue DESC
""").show()

Output:

+------+-------------+
|region|total_revenue|
+------+-------------+
| North|          300|
|  West|          180|
| South|          150|
|  East|          120|
+------+-------------+

Chain SQL queries with DataFrame operations by storing spark.sql() results and applying methods like filter() and orderBy():

# Query customers with amount greater than 100, keeping region for later filtering
result = spark.sql("""
    SELECT customer_id, name, region, amount
    FROM customers
    WHERE amount > 100
""")

# Filter SQL results using the DataFrame API, sort by amount, and drop region
(
    result.filter(col("region") != "South")
    .orderBy("amount", ascending=False)
    .select("customer_id", "name", "amount")
    .show()
)

Output:

+-----------+-------+------+
|customer_id|   name|amount|
+-----------+-------+------+
|          3|Charlie|   200|
|          5|    Eve|   180|
|          4|  Diana|   120|
+-----------+-------+------+

Use SQL syntax within spark.sql() for complex joins and aggregations when SQL is more readable than DataFrame API:

# Register orders and customers as temporary SQL views
orders.createOrReplaceTempView("orders")
customers.createOrReplaceTempView("customers_table")

# Join and aggregate using SQL syntax
spark.sql("""
    SELECT
        c.name,
        COUNT(o.order_id) as order_count,
        SUM(o.amount) as total_spent
    FROM customers_table c
    JOIN orders o ON c.customer_id = o.customer_id
    GROUP BY c.name
    ORDER BY total_spent DESC
""").show()

Output:

+-------+-----------+-----------+
|   name|order_count|total_spent|
+-------+-----------+-----------+
|  Alice|          2|        300|
|    Bob|          1|        150|
|Charlie|          1|        120|
+-------+-----------+-----------+

Custom Functions

When built-in functions aren’t sufficient, custom logic can be implemented using pandas UDFs. These user-defined functions provide vectorized performance through Apache Arrow and support both scalar operations and grouped transformations.

📚 For taking your data science projects from prototype to production, check out Production-Ready Data Science.

Create a scalar pandas UDF with the @pandas_udf decorator to apply custom Python functions to columns with vectorized performance:

from pyspark.sql.functions import pandas_udf
import pandas as pd

# Define pandas UDF to calculate discount based on price and quantity
@pandas_udf("double")
def calculate_discount(price: pd.Series, quantity: pd.Series) -> pd.Series:
    return price * quantity * 0.1

# Create sample order data with price and quantity
order_data = spark.createDataFrame(
    [(1, 100.0, 2), (2, 150.0, 3), (3, 200.0, 1)],
    ["order_id", "price", "quantity"]
)

# Apply discount calculation UDF to create discount column
order_data.withColumn(
    "discount",
    calculate_discount(col("price"), col("quantity"))
).show()

Output:

+--------+-----+--------+--------+
|order_id|price|quantity|discount|
+--------+-----+--------+--------+
|       1|100.0|       2|    20.0|
|       2|150.0|       3|    45.0|
|       3|200.0|       1|    20.0|
+--------+-----+--------+--------+

Apply custom pandas functions to grouped data with groupBy().applyInPandas():

  • Define a function that transforms each group as a pandas DataFrame
  • Specify output schema to tell PySpark the resulting column names and types

# Define function to normalize amounts within each group
def normalize_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    pdf["normalized"] = (pdf["amount"] - pdf["amount"].mean()) / pdf["amount"].std()
    return pdf

schema = "customer_id long, date timestamp, amount long, normalized double"

# Apply normalization function to each customer_id group
df_ts.groupBy("customer_id").applyInPandas(normalize_by_group, schema).show()

Output:

+-----------+-------------------+------+----------+
|customer_id|               date|amount|normalized|
+-----------+-------------------+------+----------+
|          1|2024-01-01 00:00:00|   100|      -1.0|
|          1|2024-01-02 00:00:00|   120|       1.0|
|          1|2024-01-03 00:00:00|   110|       0.0|
|          2|2024-01-01 00:00:00|   200|      -1.0|
|          2|2024-01-02 00:00:00|   220|       1.0|
|          2|2024-01-03 00:00:00|   210|       0.0|
+-----------+-------------------+------+----------+

SQL Expressions

SQL expressions can be embedded directly within DataFrame operations for complex transformations. The expr() and selectExpr() functions allow SQL syntax to be used alongside DataFrame methods, providing flexibility in query construction.

Use the expr() function to embed SQL syntax within DataFrame operations, allowing SQL-style calculations in withColumn():

from pyspark.sql.functions import expr

# Add tax and total columns using SQL expressions
(
    df.withColumn("tax", expr("amount * 0.1"))
    .withColumn("total", expr("amount + (amount * 0.1)"))
    .show()
)

Output:

+-----------+-------+------+------+----+-----+
|customer_id|   name|region|amount| tax|total|
+-----------+-------+------+------+----+-----+
|          1|  Alice| North|   100|10.0|110.0|
|          2|    Bob| South|   150|15.0|165.0|
|          3|Charlie| North|   200|20.0|220.0|
|          4|  Diana|  East|   120|12.0|132.0|
|          5|    Eve|  West|   180|18.0|198.0|
+-----------+-------+------+------+----+-----+

Unlike select() which uses column objects and method chaining, selectExpr() accepts SQL strings and is preferred for complex expressions that are simpler to write as SQL:

# Select columns with calculations and CASE statement using SQL syntax
df.selectExpr(
    "customer_id",
    "name",
    "amount * 1.1 AS amount_with_tax",
    "CASE WHEN amount > 150 THEN 'high' ELSE 'normal' END as category"
).show()

Output:

+-----------+-------+---------------+--------+
|customer_id|   name|amount_with_tax|category|
+-----------+-------+---------------+--------+
|          1|  Alice|          110.0|  normal|
|          2|    Bob|          165.0|  normal|
|          3|Charlie|          220.0|    high|
|          4|  Diana|          132.0|  normal|
|          5|    Eve|          198.0|    high|
+-----------+-------+---------------+--------+

Conclusion

You’ve learned PySpark SQL from fundamentals to advanced analytics. Here’s a quick reference:

Category | Methods | Description
DataFrame Creation | createDataFrame(), read.csv(), read.parquet() | Create DataFrames from Python data, CSV, or Parquet files
Data Exploration | show(), describe(), count(), columns, distinct(), sample() | Inspect and explore DataFrame contents
Selection & Filtering | select(), filter(), where(), drop() | Choose columns and filter rows
Column Operations | withColumn(), withColumnRenamed(), cast(), col(), lit() | Add, modify, rename columns and change types
Aggregations | groupBy(), agg(), sum(), avg(), count(), max(), min(), collect_list() | Group data and compute statistics
String Functions | concat(), split(), regexp_replace() | Manipulate text data
Date/Time Functions | year(), month(), dayofmonth(), date_format(), to_timestamp() | Extract and format date components
Window Functions | Window.partitionBy(), rank(), row_number(), dense_rank(), lag(), rowsBetween() | Rankings, time series, and running calculations
Joins | join() | Combine DataFrames with inner, left, right joins
SQL Integration | createOrReplaceTempView(), spark.sql() | Execute SQL queries on DataFrames
Custom Functions | @pandas_udf, applyInPandas() | Vectorized custom functions and group operations
SQL Expressions | expr(), selectExpr() | Embed SQL syntax in DataFrame operations

PySpark SQL gives you the power to scale from prototyping to production without rewriting your workflows, making it an essential tool for modern data science.
