SQL Index Guide (Core)

This is a basic guide to LlamaIndex's SQL index capabilities. We first show how to define a SQL table, then build a table index over the schema. This allows us to synthesize a SQL query given the user's natural language query.

import os
import openai

os.environ["OPENAI_API_KEY"] = "sk-..."
openai.api_key = os.environ["OPENAI_API_KEY"]
# import logging
# import sys

# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
from IPython.display import Markdown, display

Create Database Schema

We use SQLAlchemy, a popular SQL database toolkit, to create an empty city_stats table.

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
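
To sanity-check that the table was created, we can optionally inspect the engine. This is a minimal sketch using SQLAlchemy's built-in inspector.

from sqlalchemy import inspect

# list the tables known to the engine; should include "city_stats"
print(inspect(engine).get_table_names())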

Define SQL Database

We first define our SQLDatabase abstraction (a light wrapper around SQLAlchemy).

from llama_index import SQLDatabase, ServiceContext
from llama_index.llms import OpenAI
llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo")
service_context = ServiceContext.from_defaults(llm=llm)
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_database.table_info
'\nCREATE TABLE city_stats (\n\tcity_name VARCHAR(16) NOT NULL, \n\tpopulation INTEGER, \n\tcountry VARCHAR(16) NOT NULL, \n\tPRIMARY KEY (city_name)\n)\n\n/*\n3 rows from city_stats table:\ncity_name\tpopulation\tcountry\n\n*/'
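
Besides table_info, the wrapper can also report which tables it will expose to the LLM. The sketch below assumes the get_usable_table_names method is available in your llama_index version.

# list the tables the wrapper will expose to the LLM
# (assumes get_usable_table_names is available in your version)
print(sql_database.get_usable_table_names())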

We add some test data to our SQL database.

from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
with engine.connect() as connection:
    for row in rows:
        # insert each row into the city_stats table
        connection.execute(insert(city_stats_table).values(**row))
    connection.commit()
# view current table
stmt = select(city_stats_table.c["city_name", "population", "country"]).select_from(
    city_stats_table
)

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()
    print(results)
[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]

Query Index

First, we show how to execute a raw SQL query directly against the table.

from sqlalchemy import text

with engine.connect() as con:
    rows = con.execute(text("SELECT city_name from city_stats"))
    for row in rows:
        print(row)
('Chicago',)
('Seoul',)
('Tokyo',)
('Toronto',)
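
If a raw query needs user-supplied values, it is safer to bind parameters than to interpolate them into the SQL string. A minimal sketch using SQLAlchemy's text() bindings:

with engine.connect() as con:
    # bind the value as a parameter instead of formatting it into the string
    stmt = text("SELECT city_name, population FROM city_stats WHERE country = :country")
    for row in con.execute(stmt, {"country": "Japan"}):
        print(row)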

Natural language SQL

Once we have constructed our SQL database, we can use the NLSQLTableQueryEngine to take natural language queries and synthesize them into SQL queries.

Note that we need to specify the tables we want to use with this query engine. If we don't, the query engine will pull all the schema context, which could overflow the LLM's context window.

from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["city_stats"],
)
query_str = "Which city has the highest population?"
response = query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))

The city with the highest population is Tokyo, with a population of 13,960,000.

This query engine should be used whenever you can specify beforehand the tables you want to query over, or whenever the total size of all the table schemas plus the rest of the prompt fits within your context window.

Building our Table Index

If we don't know ahead of time which table we would like to use, and the total size of the table schemas overflows the context window, we should store the table schemas in an index so that at query time we can retrieve the right schema.

We can do this using the SQLTableNodeMapping object, which takes in a SQLDatabase and produces a Node object for each SQLTableSchema object passed into the ObjectIndex constructor.

from llama_index.indices.struct_store.sql_query import SQLTableRetrieverQueryEngine
from llama_index.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index import VectorStoreIndex

# set Logging to DEBUG for more detailed outputs
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name="city_stats")
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)

Now we can take our SQLTableRetrieverQueryEngine and query it for our response.

response = query_engine.query("Which city has the highest population?")
display(Markdown(f"<b>{response}</b>"))

The city with the highest population is Tokyo, with a population of 13,960,000.

# you can also fetch the raw result from SQLAlchemy!
response.metadata["result"]
[('Tokyo', 13960000)]
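
The metadata should also carry the SQL that was synthesized, which is handy for debugging. The sketch below assumes your llama_index version populates a sql_query key.

# inspect the generated SQL (assumes the "sql_query" metadata key is populated)
print(response.metadata["sql_query"])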

You can also add additional context information for each table schema you define.

# manually set context text
city_stats_text = (
    "This table gives information regarding the population and country of a given city.\n"
    "The user will query with codewords, where 'foo' corresponds to population and 'bar'"
    "corresponds to city."
)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name="city_stats", context_str=city_stats_text)
]
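
The context string isn't picked up until the index is rebuilt. As a sketch, we rebuild the ObjectIndex and query engine with the annotated schema, then query using the codewords (the query below is a hypothetical example):

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)
# the context string should let the LLM resolve 'foo' to population
response = query_engine.query("Which city has the highest foo?")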

Using LangChain for Querying

Since our SQLDatabase inherits from LangChain's SQLDatabase, you can also use LangChain itself for querying.

# Needs langchain_experimental package:
from langchain_experimental.sql import SQLDatabaseChain
from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
llm = OpenAI(temperature=0)
# db = SQLDatabase(...) # to fill in if not using llama_index's SQLDatabase
# set Logging to DEBUG for more detailed outputs
db_chain = SQLDatabaseChain.from_llm(llm=llm, db=sql_database)
db_chain.run("Which city has the highest population?")
'Tokyo has the highest population with 13960000 people.'
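
If you are not using llama_index's SQLDatabase, here is a minimal sketch of filling in the commented-out line above with LangChain's own wrapper around the same engine (assuming your LangChain version accepts an engine and include_tables directly):

# build LangChain's own SQLDatabase over the same engine
db = SQLDatabase(engine, include_tables=["city_stats"])
db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db)
db_chain.run("Which country is Seoul in?")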