Open In Colab

SQL Query Engine with LlamaIndex + DuckDB#

This guide showcases the core LlamaIndex SQL capabilities with DuckDB.

We go through some core LlamaIndex data structures, including the NLSQLTableQueryEngine and SQLTableRetrieverQueryEngine.

If you’re opening this Notebook on colab, you will probably need to install LlamaIndex 🦙.

%pip install llama-index-readers-wikipedia
!pip install llama-index
!pip install duckdb duckdb-engine
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
from llama_index.core import SQLDatabase, SimpleDirectoryReader, Document
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from IPython.display import Markdown, display

Basic Text-to-SQL with our NLSQLTableQueryEngine#

In this initial example, we walk through populating a SQL database with some test datapoints, and querying it with our text-to-SQL capabilities.

Create Database Schema + Test Data#

We use sqlalchemy, a popular SQL database toolkit, to connect to DuckDB and create an empty city_stats Table. We then populate it with some test data.

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    column,
)
engine = create_engine("duckdb:///:memory:")
# uncomment to make this work with MotherDuck
# engine = create_engine("duckdb:///md:llama-index")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)

metadata_obj.create_all(engine)
# print tables
metadata_obj.tables.keys()
dict_keys(['city_stats'])

We introduce some test data into the city_stats table

from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)
with engine.connect() as connection:
    cursor = connection.exec_driver_sql("SELECT * FROM city_stats")
    print(cursor.fetchall())
[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]

Create SQLDatabase Object#

We first define our SQLDatabase abstraction (a light wrapper around SQLAlchemy).

from llama_index.core import SQLDatabase
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
/Users/jerryliu/Programming/gpt_index/.venv/lib/python3.10/site-packages/duckdb_engine/__init__.py:162: DuckDBEngineWarning: duckdb-engine doesn't yet support reflection on indices
  warnings.warn(

Query Index#

Here we demonstrate the capabilities of NLSQLTableQueryEngine, which performs text-to-SQL.

  1. We construct a NLSQLTableQueryEngine and pass in our SQL database object.

  2. We run queries against the query engine.

query_engine = NLSQLTableQueryEngine(sql_database)
response = query_engine.query("Which city has the highest population?")
INFO:llama_index.indices.struct_store.sql_query:> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .
> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .
/Users/jerryliu/Programming/gpt_index/.venv/lib/python3.10/site-packages/langchain/sql_database.py:238: UserWarning: This method is deprecated - please use `get_usable_table_names`.
  warnings.warn(
INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 332 tokens
> [query] Total LLM token usage: 332 tokens
INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens
str(response)
' Tokyo has the highest population, with 13,960,000 people.'
response.metadata
{'result': [('Tokyo', 13960000)],
 'sql_query': 'SELECT city_name, population \nFROM city_stats \nORDER BY population DESC \nLIMIT 1;'}

Advanced Text-to-SQL with our SQLTableRetrieverQueryEngine#

In this guide, we tackle the setting where you have a large number of tables in your database, and putting all the table schemas into the prompt may overflow the text-to-SQL prompt.

We first index the schemas with our ObjectIndex, and then use our SQLTableRetrieverQueryEngine abstraction on top.

engine = create_engine("duckdb:///:memory:")
# uncomment to make this work with MotherDuck
# engine = create_engine("duckdb:///md:llama-index")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
all_table_names = ["city_stats"]
# create a ton of dummy tables
n = 100
for i in range(n):
    tmp_table_name = f"tmp_table_{i}"
    tmp_table = Table(
        tmp_table_name,
        metadata_obj,
        Column(f"tmp_field_{i}_1", String(16), primary_key=True),
        Column(f"tmp_field_{i}_2", Integer),
        Column(f"tmp_field_{i}_3", String(16), nullable=False),
    )
    all_table_names.append(f"tmp_table_{i}")

metadata_obj.create_all(engine)
# insert dummy data
from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)
sql_database = SQLDatabase(engine, include_tables=["city_stats"])

Construct Object Index#

from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex
table_node_mapping = SQLTableNodeMapping(sql_database)

table_schema_objs = []
for table_name in all_table_names:
    table_schema_objs.append(SQLTableSchema(table_name=table_name))

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens
> [build_index_from_nodes] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 6343 tokens
> [build_index_from_nodes] Total embedding token usage: 6343 tokens

Query Index with SQLTableRetrieverQueryEngine#

query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
)
response = query_engine.query("Which city has the highest population?")
INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens
> [retrieve] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 7 tokens
> [retrieve] Total embedding token usage: 7 tokens
INFO:llama_index.indices.struct_store.sql_query:> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .
> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .
INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 337 tokens
> [query] Total LLM token usage: 337 tokens
INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens
response
Response(response=' The city with the highest population is Tokyo, with a population of 13,960,000.', source_nodes=[], metadata={'result': [('Tokyo', 13960000)], 'sql_query': 'SELECT city_name, population \nFROM city_stats \nORDER BY population DESC \nLIMIT 1;'})