SQL Index
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
from llama_index import SimpleDirectoryReader, WikipediaReader
from IPython.display import Markdown, display
Load Wikipedia Data
# install wikipedia python package
!pip install wikipedia
wiki_docs = WikipediaReader().load_data(pages=['Toronto', 'Berlin', 'Tokyo'])
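As a quick sanity check, you can confirm that one document was loaded per requested page (a minimal sketch; get_text() is the legacy Document accessor and may differ in newer releases):
# one Document per requested Wikipedia page
print(len(wiki_docs))  # 3
# peek at the beginning of the first document's text
print(wiki_docs[0].get_text()[:100])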
Create Database Schema
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select, column
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
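To verify that the table was actually created, you can ask SQLAlchemy's inspector for the table names (a small sketch using the standard inspect API):
from sqlalchemy import inspect
# list the tables now present in the in-memory database
print(inspect(engine).get_table_names())  # ['city_stats']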
Build Index
from llama_index import GPTSQLStructStoreIndex, SQLDatabase, ServiceContext
from langchain import OpenAI
from llama_index import LLMPredictor
llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="text-davinci-002"))
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_database.table_info
"Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))."
# NOTE: table_name specifies the table you want to extract
# structured data into from the unstructured documents
index = GPTSQLStructStoreIndex.from_documents(
    wiki_docs,
    sql_database=sql_database,
    table_name="city_stats",
    service_context=service_context,
)
# view current table
stmt = select(
    city_stats_table.c["city_name", "population", "country"]
).select_from(city_stats_table)
with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()
    print(results)
[('Toronto', 2731571, 'Canada'), ('Berlin', 600000, 'Germany'), ('Tokyo', 13929286, 'Japan')]
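If you would rather populate (or correct) rows yourself instead of relying on extraction, here is a minimal sketch using SQLAlchemy's insert construct (the city values below are purely illustrative):
from sqlalchemy import insert
# engine.begin() opens a transaction and commits it on exit
with engine.begin() as connection:
    connection.execute(
        insert(city_stats_table).values(
            city_name="Chicago", population=2679000, country="United States"
        )
    )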
Query Index
We first show a raw SQL query, which executes directly against the table.
query_engine = index.as_query_engine(
    query_mode="sql"
)
response = query_engine.query("SELECT city_name from city_stats")
> [query] Total LLM token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens
display(Markdown(f"<b>{response}</b>"))
[('Berlin',), ('Tokyo',), ('Toronto',)]
Here we show a natural language query, which is translated into a SQL query under the hood.
# set Logging to DEBUG for more detailed outputs
query_engine = index.as_query_engine(
    query_mode="nl"
)
response = query_engine.query("Which city has the highest population?")
> Predicted SQL query: SELECT city_name, population
FROM city_stats
ORDER BY population DESC
LIMIT 1
> [query] Total LLM token usage: 144 tokens
> [query] Total embedding token usage: 0 tokens
display(Markdown(f"<b>{response}</b>"))
[('Tokyo', 13929286)]
# you can also fetch the raw result from SQLAlchemy!
response.extra_info["result"]
[('Tokyo', 13929286)]
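Depending on your llama_index version, extra_info may also expose the predicted SQL string itself; the "sql_query" key below is an assumption, so use .get() to avoid a KeyError:
# the generated SQL may also be surfaced (key name is an assumption)
print(response.extra_info.get("sql_query"))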
Using LangChain for Querying
Since our SQLDatabase inherits from LangChain's SQLDatabase, you can also use LangChain itself for querying.
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
llm = OpenAI(temperature=0)
# set Logging to DEBUG for more detailed outputs
db_chain = SQLDatabaseChain(llm=llm, database=sql_database)
db_chain.run("Which city has the highest population?")
> Entering new SQLDatabaseChain chain...
Which city has the highest population?
SQLQuery: SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1;
SQLResult: [('Tokyo',)]
Answer: Tokyo has the highest population.
> Finished chain.
' Tokyo has the highest population.'
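SQLDatabaseChain also accepts the standard verbose flag from LangChain's Chain base class, which prints the intermediate SQLQuery/SQLResult steps seen in the trace above:
# enable verbose output to see the generated SQL and raw result
db_chain = SQLDatabaseChain(llm=llm, database=sql_database, verbose=True)
db_chain.run("Which city has the highest population?")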