Skip to content

Sql

NLSQLRetriever #

Bases: BaseRetriever, PromptMixin

Text-to-SQL Retriever.

Retrieves via text.

Parameters:

Name Type Description Default
sql_database SQLDatabase

SQL database.

required
text_to_sql_prompt BasePromptTemplate

Prompt template for text-to-sql. Defaults to DEFAULT_TEXT_TO_SQL_PROMPT.

None
context_query_kwargs dict

Mapping from table name to context query. Defaults to None.

None
tables Union[List[str], List[Table]]

List of table names or Table objects.

None
table_retriever ObjectRetriever[SQLTableSchema]

Object retriever for SQLTableSchema objects. Defaults to None.

None
context_str_prefix str

Prefix for context string. Defaults to None.

None
service_context ServiceContext

Service context. Defaults to None.

None
return_raw bool

Whether to return a plain-text dump of the SQL results, or the results parsed into Nodes.

True
handle_sql_errors bool

Whether to handle SQL errors. Defaults to True.

True
sql_only bool

Whether to return only the generated SQL and not the SQL query result. Defaults to False.

False
llm Optional[LLM]

Language model to use.

None
Source code in llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
class NLSQLRetriever(BaseRetriever, PromptMixin):
    """Text-to-SQL Retriever.

    Retrieves via text: the LLM turns a natural-language query into a SQL
    statement, which is then executed against the database (unless
    ``sql_only`` is set, in which case only the statement is returned).

    Args:
        sql_database (SQLDatabase): SQL database.
        text_to_sql_prompt (BasePromptTemplate): Prompt template for text-to-sql.
            Defaults to DEFAULT_TEXT_TO_SQL_PROMPT.
        context_query_kwargs (dict): Mapping from table name to context query.
            The value is attached to the table's schema as its context string.
            Defaults to None.
        tables (Union[List[str], List[Table]]): List of table names or Table objects.
            When omitted (and no ``table_retriever`` is given), the usable table
            names reported by ``sql_database`` are used.
        table_retriever (ObjectRetriever[SQLTableSchema]): Object retriever for
            SQLTableSchema objects. Takes precedence over ``tables`` when given.
            Defaults to None.
        context_str_prefix (str): Prefix for context string. Defaults to None.
        sql_parser_mode (SQLParserMode): Which parser turns the LLM response into
            SQL. Defaults to SQLParserMode.DEFAULT.
        llm (Optional[LLM]): Language model to use. Falls back to the one resolved
            from Settings / ``service_context``.
        embed_model (Optional[BaseEmbedding]): Embedding model handed to the
            PGVECTOR parser. Falls back to the one resolved from Settings /
            ``service_context``.
        service_context (ServiceContext): Service context. Defaults to None.
        return_raw (bool): Whether to return plain-text dump of SQL results, or
            parsed into Nodes. Defaults to True.
        handle_sql_errors (bool): Whether to handle SQL errors, i.e. return the
            error message as a node instead of raising. Defaults to True.
        sql_only (bool): Whether to get only sql and not the sql query result.
            Defaults to False.
        callback_manager (Optional[CallbackManager]): Callback manager. Falls back
            to the one resolved from Settings / ``service_context``.
        verbose (bool): Whether to also print the table description and the
            predicted SQL query to stdout. Defaults to False.

    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        text_to_sql_prompt: Optional[BasePromptTemplate] = None,
        context_query_kwargs: Optional[dict] = None,
        tables: Optional[Union[List[str], List[Table]]] = None,
        table_retriever: Optional[ObjectRetriever[SQLTableSchema]] = None,
        context_str_prefix: Optional[str] = None,
        sql_parser_mode: SQLParserMode = SQLParserMode.DEFAULT,
        llm: Optional[LLM] = None,
        embed_model: Optional[BaseEmbedding] = None,
        service_context: Optional[ServiceContext] = None,
        return_raw: bool = True,
        handle_sql_errors: bool = True,
        sql_only: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        # The raw-SQL retriever does the actual execution of the predicted query.
        self._sql_retriever = SQLRetriever(sql_database, return_raw=return_raw)
        self._sql_database = sql_database
        self._get_tables = self._load_get_tables_fn(
            sql_database, tables, context_query_kwargs, table_retriever
        )
        self._context_str_prefix = context_str_prefix
        self._llm = llm or llm_from_settings_or_context(Settings, service_context)
        self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_PROMPT
        self._sql_parser_mode = sql_parser_mode

        embed_model = embed_model or embed_model_from_settings_or_context(
            Settings, service_context
        )
        self._sql_parser = self._load_sql_parser(sql_parser_mode, embed_model)
        self._handle_sql_errors = handle_sql_errors
        self._sql_only = sql_only
        self._verbose = verbose
        super().__init__(
            callback_manager=callback_manager
            or callback_manager_from_settings_or_context(Settings, service_context)
        )

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {
            "text_to_sql_prompt": self._text_to_sql_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "text_to_sql_prompt" in prompts:
            self._text_to_sql_prompt = prompts["text_to_sql_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt modules."""
        return {}

    def _load_sql_parser(
        self, sql_parser_mode: SQLParserMode, embed_model: BaseEmbedding
    ) -> BaseSQLParser:
        """Load SQL parser.

        Raises:
            ValueError: If ``sql_parser_mode`` is not a known mode.

        """
        if sql_parser_mode == SQLParserMode.DEFAULT:
            return DefaultSQLParser()
        elif sql_parser_mode == SQLParserMode.PGVECTOR:
            return PGVectorSQLParser(embed_model=embed_model)
        else:
            raise ValueError(f"Unknown SQL parser mode: {sql_parser_mode}")

    def _load_get_tables_fn(
        self,
        sql_database: SQLDatabase,
        tables: Optional[Union[List[str], List[Table]]] = None,
        context_query_kwargs: Optional[dict] = None,
        table_retriever: Optional[ObjectRetriever[SQLTableSchema]] = None,
    ) -> Callable[[str], List[SQLTableSchema]]:
        """Load get_tables function.

        Returns a callable mapping a query string to the table schemas that
        should be described to the LLM. With a ``table_retriever`` the schemas
        are looked up per query; otherwise a fixed list is built once here.

        """
        context_query_kwargs = context_query_kwargs or {}
        if table_retriever is not None:
            return lambda query_str: cast(Any, table_retriever).retrieve(query_str)

        if tables is not None:
            table_names: List[str] = [
                t.name if isinstance(t, Table) else t for t in tables
            ]
        else:
            table_names = list(sql_database.get_usable_table_names())
        # Tables without an entry in context_query_kwargs get context_str=None.
        table_schemas = [
            SQLTableSchema(table_name=t, context_str=context_query_kwargs.get(t))
            for t in table_names
        ]
        return lambda _: table_schemas

    @staticmethod
    def _build_sql_only_result(
        sql_query_str: str,
    ) -> Tuple[List[NodeWithScore], Dict[str, Any]]:
        """Build the (nodes, metadata) pair returned in ``sql_only`` mode."""
        sql_only_node = TextNode(text=f"{sql_query_str}")
        return [NodeWithScore(node=sql_only_node)], {"result": sql_query_str}

    @staticmethod
    def _build_error_result(
        error: Exception,
    ) -> Tuple[List[NodeWithScore], Dict[str, Any]]:
        """Build the (nodes, metadata) pair that reports a handled SQL error."""
        err_node = TextNode(text=f"Error: {error!s}")
        return [NodeWithScore(node=err_node)], {}

    def retrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """Retrieve with metadata.

        Args:
            str_or_query_bundle (QueryType): Natural-language query, either as a
                plain string or as a QueryBundle.

        Returns:
            Tuple[List[NodeWithScore], Dict]: The retrieved nodes and a metadata
                dict holding the predicted query under ``"sql_query"`` merged
                with the metadata of the executed query (or, in ``sql_only``
                mode, the query itself under ``"result"``).

        Raises:
            Exception: Whatever executing the SQL raised, when
                ``handle_sql_errors`` is False.

        """
        if isinstance(str_or_query_bundle, str):
            query_bundle = QueryBundle(str_or_query_bundle)
        else:
            query_bundle = str_or_query_bundle
        table_desc_str = self._get_table_context(query_bundle)
        logger.info(f"> Table desc str: {table_desc_str}")
        if self._verbose:
            print(f"> Table desc str: {table_desc_str}")

        response_str = self._llm.predict(
            self._text_to_sql_prompt,
            query_str=query_bundle.query_str,
            schema=table_desc_str,
            dialect=self._sql_database.dialect,
        )

        sql_query_str = self._sql_parser.parse_response_to_sql(
            response_str, query_bundle
        )
        # assume that it's a valid SQL query
        logger.debug(f"> Predicted SQL query: {sql_query_str}")
        if self._verbose:
            print(f"> Predicted SQL query: {sql_query_str}")

        if self._sql_only:
            retrieved_nodes, metadata = self._build_sql_only_result(sql_query_str)
        else:
            try:
                retrieved_nodes, metadata = self._sql_retriever.retrieve_with_metadata(
                    sql_query_str
                )
            # Exception, not BaseException: KeyboardInterrupt / SystemExit must
            # never be turned into an "Error: ..." node.
            except Exception as e:
                if not self._handle_sql_errors:
                    raise
                retrieved_nodes, metadata = self._build_error_result(e)

        return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

    async def aretrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """Async retrieve with metadata.

        Mirrors ``retrieve_with_metadata`` (same arguments, return value and
        error handling) but awaits the LLM prediction and the SQL retriever.

        """
        if isinstance(str_or_query_bundle, str):
            query_bundle = QueryBundle(str_or_query_bundle)
        else:
            query_bundle = str_or_query_bundle
        table_desc_str = self._get_table_context(query_bundle)
        logger.info(f"> Table desc str: {table_desc_str}")
        if self._verbose:
            print(f"> Table desc str: {table_desc_str}")

        response_str = await self._llm.apredict(
            self._text_to_sql_prompt,
            query_str=query_bundle.query_str,
            schema=table_desc_str,
            dialect=self._sql_database.dialect,
        )

        sql_query_str = self._sql_parser.parse_response_to_sql(
            response_str, query_bundle
        )
        # assume that it's a valid SQL query
        logger.debug(f"> Predicted SQL query: {sql_query_str}")
        if self._verbose:
            print(f"> Predicted SQL query: {sql_query_str}")

        if self._sql_only:
            retrieved_nodes, metadata = self._build_sql_only_result(sql_query_str)
        else:
            try:
                (
                    retrieved_nodes,
                    metadata,
                ) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
            # Exception, not BaseException: asyncio.CancelledError must keep
            # propagating so the surrounding task can actually be cancelled.
            except Exception as e:
                if not self._handle_sql_errors:
                    raise
                retrieved_nodes, metadata = self._build_error_result(e)
        return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        retrieved_nodes, _ = self.retrieve_with_metadata(query_bundle)
        return retrieved_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async retrieve nodes given query."""
        retrieved_nodes, _ = await self.aretrieve_with_metadata(query_bundle)
        return retrieved_nodes

    def _get_table_context(self, query_bundle: QueryBundle) -> str:
        """Get table context.

        Get tables schema + optional context as a single string. Entries (the
        optional prefix first, then one per table) are separated by a blank
        line.

        """
        table_schema_objs = self._get_tables(query_bundle.query_str)
        context_strs = []
        if self._context_str_prefix is not None:
            context_strs = [self._context_str_prefix]

        for table_schema_obj in table_schema_objs:
            table_info = self._sql_database.get_single_table_info(
                table_schema_obj.table_name
            )

            if table_schema_obj.context_str:
                table_opt_context = " The table description is: "
                table_opt_context += table_schema_obj.context_str
                table_info += table_opt_context

            context_strs.append(table_info)

        return "\n\n".join(context_strs)

retrieve_with_metadata #

retrieve_with_metadata(str_or_query_bundle: QueryType) -> Tuple[List[NodeWithScore], Dict]

Retrieve with metadata.

Source code in llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
def retrieve_with_metadata(
    self, str_or_query_bundle: QueryType
) -> Tuple[List[NodeWithScore], Dict]:
    """Retrieve with metadata.

    Asks the LLM to translate the query into SQL and, unless ``sql_only`` is
    set, executes the predicted statement through the SQL retriever.

    Args:
        str_or_query_bundle (QueryType): Natural-language query, either as a
            plain string or as a QueryBundle.

    Returns:
        Tuple[List[NodeWithScore], Dict]: The retrieved nodes and a metadata
            dict holding the predicted query under ``"sql_query"`` merged with
            the metadata of the executed query (or, in ``sql_only`` mode, the
            query itself under ``"result"``).

    Raises:
        Exception: Whatever executing the SQL raised, when
            ``handle_sql_errors`` is False.

    """
    if isinstance(str_or_query_bundle, str):
        query_bundle = QueryBundle(str_or_query_bundle)
    else:
        query_bundle = str_or_query_bundle
    table_desc_str = self._get_table_context(query_bundle)
    logger.info(f"> Table desc str: {table_desc_str}")
    if self._verbose:
        print(f"> Table desc str: {table_desc_str}")

    response_str = self._llm.predict(
        self._text_to_sql_prompt,
        query_str=query_bundle.query_str,
        schema=table_desc_str,
        dialect=self._sql_database.dialect,
    )

    sql_query_str = self._sql_parser.parse_response_to_sql(
        response_str, query_bundle
    )
    # assume that it's a valid SQL query
    logger.debug(f"> Predicted SQL query: {sql_query_str}")
    if self._verbose:
        print(f"> Predicted SQL query: {sql_query_str}")

    if self._sql_only:
        sql_only_node = TextNode(text=f"{sql_query_str}")
        retrieved_nodes = [NodeWithScore(node=sql_only_node)]
        metadata = {"result": sql_query_str}
    else:
        try:
            retrieved_nodes, metadata = self._sql_retriever.retrieve_with_metadata(
                sql_query_str
            )
        # Exception, not BaseException: KeyboardInterrupt / SystemExit must
        # never be turned into an "Error: ..." node.
        except Exception as e:
            # if handle_sql_errors is True, then return error message
            if self._handle_sql_errors:
                err_node = TextNode(text=f"Error: {e!s}")
                retrieved_nodes = [NodeWithScore(node=err_node)]
                metadata = {}
            else:
                raise

    return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

aretrieve_with_metadata async #

aretrieve_with_metadata(str_or_query_bundle: QueryType) -> Tuple[List[NodeWithScore], Dict]

Async retrieve with metadata.

Source code in llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
async def aretrieve_with_metadata(
    self, str_or_query_bundle: QueryType
) -> Tuple[List[NodeWithScore], Dict]:
    """Async retrieve with metadata.

    Asks the LLM (awaited) to translate the query into SQL and, unless
    ``sql_only`` is set, awaits the SQL retriever for the predicted statement.

    Args:
        str_or_query_bundle (QueryType): Natural-language query, either as a
            plain string or as a QueryBundle.

    Returns:
        Tuple[List[NodeWithScore], Dict]: The retrieved nodes and a metadata
            dict holding the predicted query under ``"sql_query"`` merged with
            the metadata of the executed query (or, in ``sql_only`` mode, the
            query itself under ``"result"``).

    Raises:
        Exception: Whatever executing the SQL raised, when
            ``handle_sql_errors`` is False.

    """
    if isinstance(str_or_query_bundle, str):
        query_bundle = QueryBundle(str_or_query_bundle)
    else:
        query_bundle = str_or_query_bundle
    table_desc_str = self._get_table_context(query_bundle)
    logger.info(f"> Table desc str: {table_desc_str}")
    if self._verbose:
        print(f"> Table desc str: {table_desc_str}")

    response_str = await self._llm.apredict(
        self._text_to_sql_prompt,
        query_str=query_bundle.query_str,
        schema=table_desc_str,
        dialect=self._sql_database.dialect,
    )

    sql_query_str = self._sql_parser.parse_response_to_sql(
        response_str, query_bundle
    )
    # assume that it's a valid SQL query
    logger.debug(f"> Predicted SQL query: {sql_query_str}")
    if self._verbose:
        print(f"> Predicted SQL query: {sql_query_str}")

    if self._sql_only:
        sql_only_node = TextNode(text=f"{sql_query_str}")
        retrieved_nodes = [NodeWithScore(node=sql_only_node)]
        # Same "result" key as the synchronous path reports in sql_only mode.
        metadata: Dict[str, Any] = {"result": sql_query_str}
    else:
        try:
            (
                retrieved_nodes,
                metadata,
            ) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
        # Exception, not BaseException: asyncio.CancelledError must keep
        # propagating so the surrounding task can actually be cancelled.
        except Exception as e:
            # if handle_sql_errors is True, then return error message
            if self._handle_sql_errors:
                err_node = TextNode(text=f"Error: {e!s}")
                retrieved_nodes = [NodeWithScore(node=err_node)]
                metadata = {}
            else:
                raise
    return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

SQLParserMode #

Bases: str, Enum

SQL Parser Mode.

Source code in llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
105
106
107
108
109
class SQLParserMode(str, Enum):
    """Selects the parser that turns an LLM response into a SQL statement.

    Members are also plain strings, so they compare equal to their values
    (``SQLParserMode.DEFAULT == "default"``).
    """

    # Resolved to DefaultSQLParser by NLSQLRetriever._load_sql_parser.
    DEFAULT = "default"
    # Resolved to PGVectorSQLParser by NLSQLRetriever._load_sql_parser.
    PGVECTOR = "pgvector"

SQLRetriever #

Bases: BaseRetriever

SQL Retriever.

Retrieves via raw SQL statements.

Parameters:

Name Type Description Default
sql_database SQLDatabase

SQL database.

required
return_raw bool

Whether to return raw results or format results. Defaults to True.

True
Source code in llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
class SQLRetriever(BaseRetriever):
    """Retriever that executes raw SQL statements.

    The incoming query string is run as-is against the database.

    Args:
        sql_database (SQLDatabase): Database the statements are run against.
        return_raw (bool): If True, return the raw result string in a single
            node; if False, return one node per result row. Defaults to True.

    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        return_raw: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Store the database handle and the output mode."""
        self._return_raw = return_raw
        self._sql_database = sql_database
        super().__init__(callback_manager)

    def _format_node_results(
        self, results: List[List[Any]], col_keys: List[str]
    ) -> List[NodeWithScore]:
        """Turn result rows into nodes, one node per row.

        Each row is paired with ``col_keys`` and stored as the node's
        metadata; the node text is left empty.
        """
        return [
            # NOTE: leave text field blank for now
            NodeWithScore(node=TextNode(text="", metadata=dict(zip(col_keys, row))))
            for row in results
        ]

    def retrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """Run the SQL statement and return nodes plus the run metadata.

        A plain string is wrapped in a QueryBundle first; its ``query_str``
        is what gets passed to ``run_sql``.
        """
        query_bundle = (
            QueryBundle(str_or_query_bundle)
            if isinstance(str_or_query_bundle, str)
            else str_or_query_bundle
        )
        raw_response_str, metadata = self._sql_database.run_sql(query_bundle.query_str)

        if not self._return_raw:
            # Formatted mode reads the rows and column names from the metadata.
            rows = metadata["result"]
            columns = metadata["col_keys"]
            return self._format_node_results(rows, columns), metadata

        raw_node = TextNode(text=raw_response_str)
        return [NodeWithScore(node=raw_node)], metadata

    async def aretrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """Async entry point; delegates to the synchronous implementation."""
        return self.retrieve_with_metadata(str_or_query_bundle)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return only the nodes for the query, dropping the metadata."""
        nodes, _metadata = self.retrieve_with_metadata(query_bundle)
        return nodes

retrieve_with_metadata #

retrieve_with_metadata(str_or_query_bundle: QueryType) -> Tuple[List[NodeWithScore], Dict]

Retrieve with metadata.

Source code in llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def retrieve_with_metadata(
    self, str_or_query_bundle: QueryType
) -> Tuple[List[NodeWithScore], Dict]:
    """Run the SQL statement and return nodes plus the run metadata.

    A plain string is wrapped in a QueryBundle first; its ``query_str`` is
    what gets passed to ``run_sql``. With ``return_raw`` the raw result
    string is returned in a single node, otherwise one node per result row.
    """
    query_bundle = (
        QueryBundle(str_or_query_bundle)
        if isinstance(str_or_query_bundle, str)
        else str_or_query_bundle
    )
    raw_response_str, metadata = self._sql_database.run_sql(query_bundle.query_str)

    if not self._return_raw:
        # Formatted mode reads the rows and column names from the metadata.
        rows = metadata["result"]
        columns = metadata["col_keys"]
        return self._format_node_results(rows, columns), metadata

    raw_node = TextNode(text=raw_response_str)
    return [NodeWithScore(node=raw_node)], metadata