SQL Index

import logging
import sys

# NOTE: basicConfig already attaches a stdout handler; adding a second
# StreamHandler would double-print every log line
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
from llama_index import WikipediaReader
from IPython.display import Markdown, display

Load Wikipedia Data

# install the wikipedia Python package (required by WikipediaReader)
!pip install wikipedia
wiki_docs = WikipediaReader().load_data(pages=['Toronto', 'Berlin', 'Tokyo'])
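
As a quick sanity check you can inspect what the reader returned (a minimal sketch; it assumes the loader yields one llama_index Document per page, each exposing get_text()):

# one Document per requested page
print(len(wiki_docs))
# preview the first 200 characters of the first article
print(wiki_docs[0].get_text()[:200])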

Create Database Schema

from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
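
To confirm the table exists before building the index, you can ask SQLAlchemy's inspector (standard inspect() API):

from sqlalchemy import inspect

# the in-memory SQLite database should now contain the city_stats table
print(inspect(engine).get_table_names())  # ['city_stats']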

Build Index

from llama_index import GPTSQLStructStoreIndex, SQLDatabase, ServiceContext, LLMPredictor
from langchain import OpenAI

llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="text-davinci-002"))
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_database.table_info
"Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))."
# NOTE: the table_name specified here is the table that data extracted
# from the unstructured documents will be inserted into.
index = GPTSQLStructStoreIndex.from_documents(
    wiki_docs, 
    sql_database=sql_database, 
    table_name="city_stats",
    service_context=service_context
)
# view current table
# attribute-style column access works across SQLAlchemy 1.4 and 2.0
stmt = select(
    city_stats_table.c.city_name,
    city_stats_table.c.population,
    city_stats_table.c.country,
).select_from(city_stats_table)

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()
    print(results)
[('Toronto', 2731571, 'Canada'), ('Berlin', 600000, 'Germany'), ('Tokyo', 13929286, 'Japan')]
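
The index populated these rows from the documents automatically, but you can also insert rows by hand with plain SQLAlchemy Core (a hedged sketch, not executed in this walkthrough; the city and figures are purely illustrative):

from sqlalchemy import insert

# illustrative manual insert alongside the LLM-extracted rows
with engine.begin() as connection:
    connection.execute(
        insert(city_stats_table).values(
            city_name="Chicago", population=2746388, country="United States"
        )
    )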

Query Index

We first show a raw SQL query, which executes directly against the table.

query_engine = index.as_query_engine(
    query_mode="sql"
)
response = query_engine.query("SELECT city_name from city_stats")
> [query] Total LLM token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens
display(Markdown(f"<b>{response}</b>"))

[('Berlin',), ('Tokyo',), ('Toronto',)]
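
In "sql" mode the query string is passed through to the database as-is, so any statement that is valid against the included tables works. An illustrative example (not run above):

# raw SQL executes directly; no LLM call is made
response = query_engine.query(
    "SELECT city_name, population FROM city_stats WHERE population > 5000000"
)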

Here we show a natural language query, which is translated into a SQL query under the hood.

# set Logging to DEBUG for more detailed outputs
query_engine = index.as_query_engine(
    query_mode="nl"
)
response = query_engine.query("Which city has the highest population?")
> Predicted SQL query: SELECT city_name, population
FROM city_stats
ORDER BY population DESC
LIMIT 1
> [query] Total LLM token usage: 144 tokens
> [query] Total embedding token usage: 0 tokens
display(Markdown(f"<b>{response}</b>"))

[('Tokyo', 13929286)]

# you can also fetch the raw result from SQLAlchemy! 
response.extra_info["result"]
[('Tokyo', 13929286)]

Using LangChain for Querying

Since our SQLDatabase wrapper subclasses LangChain's SQLDatabase, you can also use LangChain itself for querying.

# the chain reuses the sql_database object created above
from langchain import OpenAI, SQLDatabaseChain
llm = OpenAI(temperature=0)
# set Logging to DEBUG for more detailed outputs
# verbose=True prints the chain's intermediate steps (shown below)
db_chain = SQLDatabaseChain(llm=llm, database=sql_database, verbose=True)
db_chain.run("Which city has the highest population?")
> Entering new SQLDatabaseChain chain...
Which city has the highest population? 
SQLQuery: SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1;
SQLResult: [('Tokyo',)]
Answer: Tokyo has the highest population.
> Finished chain.
' Tokyo has the highest population.'