Have you ever wanted to ask a question about your PySpark data in plain English instead of writing SQL?
LangChain’s Spark SQL Toolkit enables natural language data querying by:
- Translating your requests into SQL
- Executing them against your Spark cluster
- Returning the results in a readable format
This makes it much easier to work with large-scale data while still leveraging Spark’s powerful distributed computing capabilities.
To demonstrate, we’ll create a simple DataFrame and use LangChain’s Spark SQL tool to query it.
from pyspark.sql import SparkSession, Row
# Create sample data and DataFrame
data = [Row(name="Alice", age=30), Row(name="Bob", age=25)]
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(data)
# Overwrite so the example can be re-run without a "table already exists" error
df.write.mode("overwrite").saveAsTable("people")
df.show()
This creates a table named people that is accessible via SQL.
Next, we’ll set up the key components that enable natural language querying of our Spark data. Here are the steps:
- Initialize the Spark SQL tool, which provides the interface to our Spark database.
- Initialize a language model.
- Initialize the Spark SQL toolkit, which connects the language model with the Spark database.
- Create an agent executor that combines the language model with the Spark SQL toolkit.
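# Import paths for recent LangChain releases; they may differ in older versions
from langchain_community.agent_toolkits import SparkSQLToolkit, create_spark_sql_agent
from langchain_community.utilities.spark_sql import SparkSQL
from langchain_openai import ChatOpenAI  # expects OPENAI_API_KEY in the environment
# Initialize Spark SQL tool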
spark_sql = SparkSQL(schema="default")
# Initialize LLM
llm = ChatOpenAI(temperature=0)
# Initialize toolkit
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
# Create agent executor
agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
For a hands-on guide on how to build coordinated AI agents with LangGraph, check out Building Coordinated AI Agents with LangGraph: A Hands-On Tutorial.
Now we can ask the agent to query the data.
agent_executor.run("What is the average age of people in the table?")
> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input:
Observation: people
Thought:I can query the "people" table for the average age.
Action: query_sql_db
Action Input: SELECT AVG(age) FROM people
Observation: [('27.5',)]
Thought:The average age of people in the table is 27.5.
Final Answer: 27.5
> Finished chain.
The answer for the average age is correct: (30 + 25) / 2 = 27.5.
The output shows that the agent:
- Looked up the available tables
- Queried the people table for the average age
- Got the result
- Answered the question with the result
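The same sequence can be traced by hand. The sketch below is not the toolkit's code: it uses Python's built-in sqlite3 module as a stand-in for Spark so that it runs without a cluster or an LLM, and the helper names list_tables and run_query are hypothetical, chosen only to mirror the list_tables_sql_db and query_sql_db actions in the log.

```python
import sqlite3

# Stand-in for the Spark table: an in-memory SQLite database with the same rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE people (name TEXT, age INTEGER)")
conn.executemany("INSERT INTO people VALUES (?, ?)", [("Alice", 30), ("Bob", 25)])


def list_tables(conn):
    """Hypothetical helper mirroring the list_tables_sql_db action."""
    rows = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    return [name for (name,) in rows]


def run_query(conn, sql):
    """Hypothetical helper mirroring the query_sql_db action."""
    return conn.execute(sql).fetchall()


# Step 1: look up the available tables.
tables = list_tables(conn)
# Step 2: run the SQL that the agent generated for the question.
result = run_query(conn, "SELECT AVG(age) FROM people")
# Step 3: read the single value out of the result set.
average_age = result[0][0]
# Step 4: answer the question with that value.
print(f"Tables: {tables}, average age: {average_age}")
```

In the real agent, the language model decides which action to take at each step and writes the SQL itself; here both are hard-coded to keep the control flow visible.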
Let’s try another question.
agent_executor.run("Who is the oldest person in the table?")
> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input:
Observation: people
Thought:I should query the "people" table to find the oldest person.
Action: schema_sql_db
Action Input: people
Observation: CREATE TABLE spark_catalog.default.people (
name STRING,
age BIGINT)
;
/*
3 rows from people table:
name age
Alice 30
Bob 25
*/
Thought:I should write a query to select the oldest person from the "people" table.
Action: query_sql_db
Action Input: SELECT name, age FROM people ORDER BY age DESC LIMIT 1
Observation: [('Alice', '30')]
Thought:I now know the final answer
Final Answer: Alice
> Finished chain.
The answer for the oldest person is also correct.
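Because the sample data is tiny, both answers can be cross-checked without Spark or an LLM. This is a minimal sketch in plain Python; the people list is an assumption that simply restates the two rows written to the table earlier.

```python
from statistics import mean

# The same two rows that were written to the people table.
people = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]

# Average age, matching the agent's first answer.
average_age = mean(person["age"] for person in people)

# Oldest person, matching the agent's second answer.
oldest = max(people, key=lambda person: person["age"])

print(average_age)
print(oldest["name"])
```

A check like this only works while the data fits in memory, but it is a cheap way to confirm that the SQL the agent generated means what you expect before pointing it at a larger table.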