Chat With Your SQL Database
Last Updated: December 10, 2024
In this example, we query a SQL database!
Install Dependencies
In this demo, we use SQLite.
The first few code cells in this section fetch a CSV file about absenteeism at work and create a SQL table from it.
!pip install git+https://github.com/deepset-ai/haystack.git@main#egg=haystack-ai
from urllib.request import urlretrieve
from zipfile import ZipFile
import pandas as pd
url = "https://archive.ics.uci.edu/static/public/445/absenteeism+at+work.zip"
# download the file
urlretrieve(url, "Absenteeism_at_work_AAA.zip")
print("Extracting the Absenteeism at work dataset...")
# Extract the CSV file
with ZipFile("Absenteeism_at_work_AAA.zip", 'r') as zf:
    zf.extractall()
# Check the extracted CSV file name (in this case, it's "Absenteeism_at_work.csv")
csv_file_name = "Absenteeism_at_work.csv"
print("Cleaning up the Absenteeism at work dataset...")
# Data clean up
df = pd.read_csv(csv_file_name, sep=";")
df.columns = df.columns.str.replace(' ', '_')
df.columns = df.columns.str.replace('/', '_')
Extracting the Absenteeism at work dataset...
Cleaning up the Absenteeism at work dataset...
columns = df.columns.to_list()
columns = ', '.join(columns)
columns
'ID, Reason_for_absence, Month_of_absence, Day_of_the_week, Seasons, Transportation_expense, Distance_from_Residence_to_Work, Service_time, Age, Work_load_Average_day_, Hit_target, Disciplinary_failure, Education, Son, Social_drinker, Social_smoker, Pet, Weight, Height, Body_mass_index, Absenteeism_time_in_hours'
import sqlite3
connection = sqlite3.connect('absenteeism.db')
print("Opened database successfully");
connection.execute('''CREATE TABLE IF NOT EXISTS absenteeism (ID integer,
Reason_for_absence integer,
Month_of_absence integer,
Day_of_the_week integer,
Seasons integer,
Transportation_expense integer,
Distance_from_Residence_to_Work integer,
Service_time integer,
Age integer,
Work_load_Average_day_ integer,
Hit_target integer,
Disciplinary_failure integer,
Education integer,
Son integer,
Social_drinker integer,
Social_smoker integer,
Pet integer,
Weight integer,
Height integer,
Body_mass_index integer,
Absenteeism_time_in_hours integer);''')
connection.commit()
Opened database successfully
# if_exists='replace' drops the table created above and recreates it from the DataFrame
df.to_sql('absenteeism', connection, if_exists='replace', index=False)
740
connection.close()
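The cells above rely on pandas to do the loading. As a rough illustration of what that step amounts to, here is a minimal standard-library sketch that cleans header names the same way and loads semicolon-separated rows into an in-memory SQLite table. The two sample rows, the table name demo, and the helper names clean_column and load_csv are made up for illustration; they are not taken from the dataset or the notebook.

```python
import csv
import io
import sqlite3

# Hypothetical two-row sample in the same semicolon-separated layout as the dataset.
SAMPLE_CSV = "ID;Reason for absence;Work load Average/day \n11;26;239\n36;0;239\n"

def clean_column(name: str) -> str:
    """Mirror the notebook's clean-up: spaces and slashes become underscores."""
    return name.replace(" ", "_").replace("/", "_")

def load_csv(connection: sqlite3.Connection, text: str, table: str) -> list:
    """Create `table` from the CSV header and insert every remaining row."""
    reader = csv.reader(io.StringIO(text), delimiter=";")
    columns = [clean_column(c) for c in next(reader)]
    column_defs = ", ".join(f'"{c}" integer' for c in columns)
    placeholders = ", ".join("?" for _ in columns)
    connection.execute(f'CREATE TABLE "{table}" ({column_defs})')
    connection.executemany(f'INSERT INTO "{table}" VALUES ({placeholders})', reader)
    connection.commit()
    return columns

demo_connection = sqlite3.connect(":memory:")
demo_columns = load_csv(demo_connection, SAMPLE_CSV, "demo")
row_count = demo_connection.execute('SELECT COUNT(*) FROM "demo"').fetchone()[0]
print(demo_columns, row_count)
```

Note that the trailing space in the last header is what produces the trailing underscore in Work_load_Average_day_.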
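Create the SQL Query Component
Here, we create a custom component called SQLQuery so that we can use it in a Haystack pipeline like any other component (for example, retrievers and generators). This component has a single job:
- It accepts queries (SQL queries)
- It queries the database with these SQL queries and returns the results from the database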
from typing import List
from haystack import component
@component
class SQLQuery:

    def __init__(self, sql_database: str):
        # check_same_thread=False lets the connection be used from other threads
        self.connection = sqlite3.connect(sql_database, check_same_thread=False)

    @component.output_types(results=List[str], queries=List[str])
    def run(self, queries: List[str]):
        results = []
        for query in queries:
            # Each result is a DataFrame, stored as its string representation
            result = pd.read_sql(query, self.connection)
            results.append(f"{result}")
        return {"results": results, "queries": queries}
Try the SQLQuery Component
sql_query = SQLQuery('absenteeism.db')
result = sql_query.run(queries=['SELECT Age, SUM(Absenteeism_time_in_hours) as Total_Absenteeism_Hours FROM absenteeism WHERE Disciplinary_failure = 0 GROUP BY Age ORDER BY Total_Absenteeism_Hours DESC LIMIT 3;'])
print(result["results"][0])
Age Total_Absenteeism_Hours
0 28 651
1 33 538
2 38 482
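Query a SQL Database with Natural Language
In this section, we build a simple pipeline that:
- Accepts natural language questions
- Translates those questions into SQL queries
- Queries our database with the SQLQuery component
Downside: this pipeline still runs when asked a completely unrelated question that our database cannot answer. Note how the SQLQuery component throws an error in those cases.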
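To make that failure mode concrete, here is a small sketch that uses sqlite3 directly rather than the SQLQuery component. A query that refers to a column the table does not have raises sqlite3.OperationalError. When the query goes through pd.read_sql, as it does in SQLQuery, expect the error to surface wrapped in a pandas exception instead. The one-column table and the helper try_query are stand-ins for illustration, not part of the notebook.

```python
import sqlite3

# A tiny stand-in table; the real dataset has many more columns.
demo_connection = sqlite3.connect(":memory:")
demo_connection.execute("CREATE TABLE absenteeism (Age integer)")
demo_connection.execute("INSERT INTO absenteeism VALUES (28)")

def try_query(connection, query):
    """Run a query and report either the rows or the error message."""
    try:
        return {"ok": True, "rows": connection.execute(query).fetchall()}
    except sqlite3.Error as error:
        return {"ok": False, "error": str(error)}

answerable = try_query(demo_connection, "SELECT Age FROM absenteeism")
unanswerable = try_query(demo_connection, "SELECT Birthday FROM absenteeism")
print(answerable)
print(unanswerable)
```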
import os
from getpass import getpass
os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")
OpenAI API Key: ··········
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators.openai import OpenAIGenerator
prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
The query is to be answered for the table called 'absenteeism' with the following
Columns: {{columns}};
Answer:""")
sql_query = SQLQuery('absenteeism.db')
llm = OpenAIGenerator(model="gpt-4")
sql_pipeline = Pipeline()
sql_pipeline.add_component("prompt", prompt)
sql_pipeline.add_component("llm", llm)
sql_pipeline.add_component("sql_querier", sql_query)
sql_pipeline.connect("prompt", "llm")
sql_pipeline.connect("llm.replies", "sql_querier.queries")
# If you want to draw the pipeline, uncomment below 👇
# sql_pipeline.show()
result = sql_pipeline.run({"prompt": {"question": "On which days of the week does the average absenteeism time exceed 4 hours?",
"columns": columns}})
print(result["sql_querier"]["results"][0])
Day_of_the_week
0 2
1 3
2 4
3 5
4 6
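Skip for Unrelated Questions: Add a Condition
Now, let's create another pipeline that avoids querying the database when the question is unrelated to the information it holds. To do this, we made a few changes:
- We modified the prompt so that it returns no_answer if the question cannot be answered given the database and its columns.
- We added a conditional router that routes the query to the SQLQuery component only if the question was evaluated as answerable.
- We added fallback_prompt and fallback_llm to return a statement that the question cannot be answered, along with the reason. This branch of the pipeline runs only when the question cannot be answered.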
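The routing decision described above can be sketched in plain Python. This is not how ConditionalRouter is implemented; it is a hypothetical stand-in (the function name route is an assumption) that shows the intended behavior: exactly one of the two outputs, sql or go_to_fallback, is produced for a given LLM reply.

```python
from typing import Dict, List

def route(replies: List[str], question: str) -> Dict[str, object]:
    """Plain-Python stand-in for the two routes: one output name per branch."""
    if "no_answer" in replies[0]:
        # The LLM judged the question unanswerable: pass the question to the fallback branch.
        return {"go_to_fallback": question}
    # Otherwise forward the generated SQL to the querying branch.
    return {"sql": replies}

answerable_route = route(["SELECT COUNT(*) FROM absenteeism;"], "How many rows are there?")
fallback_route = route(["no_answer"], "When is my birthday?")
print(answerable_route)
print(fallback_route)
```

Because the check is a substring match on the first reply, a reply that contains no_answer anywhere, even alongside SQL, goes to the fallback branch.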
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators.openai import OpenAIGenerator
from haystack.components.routers import ConditionalRouter
prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
If the question cannot be answered given the provided table and columns, return 'no_answer'
The query is to be answered for the table called 'absenteeism' with the following
Columns: {{columns}};
Answer:""")
llm = OpenAIGenerator(model="gpt-4")
sql_query = SQLQuery('absenteeism.db')
routes = [
    {
        "condition": "{{'no_answer' not in replies[0]}}",
        "output": "{{replies}}",
        "output_name": "sql",
        "output_type": List[str],
    },
    {
        "condition": "{{'no_answer' in replies[0]}}",
        "output": "{{question}}",
        "output_name": "go_to_fallback",
        "output_type": str,
    },
]
router = ConditionalRouter(routes)
fallback_prompt = PromptBuilder(template="""User entered a query that cannot be answered with the given table.
The query was: {{question}} and the table had columns: {{columns}}.
Let the user know why the question cannot be answered""")
fallback_llm = OpenAIGenerator(model="gpt-4")
conditional_sql_pipeline = Pipeline()
conditional_sql_pipeline.add_component("prompt", prompt)
conditional_sql_pipeline.add_component("llm", llm)
conditional_sql_pipeline.add_component("router", router)
conditional_sql_pipeline.add_component("fallback_prompt", fallback_prompt)
conditional_sql_pipeline.add_component("fallback_llm", fallback_llm)
conditional_sql_pipeline.add_component("sql_querier", sql_query)
conditional_sql_pipeline.connect("prompt", "llm")
conditional_sql_pipeline.connect("llm.replies", "router.replies")
conditional_sql_pipeline.connect("router.sql", "sql_querier.queries")
conditional_sql_pipeline.connect("router.go_to_fallback", "fallback_prompt.question")
conditional_sql_pipeline.connect("fallback_prompt", "fallback_llm")
#if you want to draw the pipeline, uncomment below 👇
#conditional_sql_pipeline.show()
question = "When is my birthday?"
result = conditional_sql_pipeline.run({"prompt": {"question": question,
"columns": columns},
"router": {"question": question},
"fallback_prompt": {"columns": columns}})
if 'sql_querier' in result:
    print(result['sql_querier']['results'][0])
elif 'fallback_llm' in result:
    print(result['fallback_llm']['replies'][0])
The query cannot be answered as the provided table does not contain information regarding the user's personal data such as birthdays. The table primarily focuses on absence-related data for presumably work or similar situations. Please provide the relevant data to get the accurate answer.
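Function Calling to Query a SQL Database
Now let's try something a bit more fun. Instead of using a component, we provide SQL querying as a function. Since we've already built it, we can simply wrap the SQLQuery component in a function 👇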
sql_query = SQLQuery('absenteeism.db')
def sql_query_func(queries: List[str]):
    try:
        result = sql_query.run(queries)
        return {"reply": result["results"][0]}
    except Exception as e:
        reply = f"""There was an error running the SQL Query = {queries}
The error is {e},
You should probably try again.
"""
        return {"reply": reply}
Define Tools
Now, let's provide this function as a tool. Below, we use OpenAI for demonstration purposes, so we follow its function definition schema 👇
tools = [
    {
        "type": "function",
        "function": {
            "name": "sql_query_func",
            "description": f"This is a tool useful to query a SQL table called 'absenteeism' with the following Columns: {columns}",
            "parameters": {
                "type": "object",
                "properties": {
                    "queries": {
                        "type": "array",
                        "description": "The SQL queries to run against the table. Infer them from the user's message.",
                        "items": {
                            "type": "string",
                        }
                    }
                },
                "required": ["queries"],
            },
        },
    }
]
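A tool definition is easy to get subtly wrong, for example by listing a required parameter that is not declared under properties. The helper below is a quick sanity check, assuming the OpenAI-style layout used in this notebook. The name missing_required and both sample definitions are hypothetical; the check is a convenience sketch, not something Haystack or OpenAI provides.

```python
def missing_required(tool: dict) -> list:
    """Return names listed under 'required' that are not declared in 'properties'."""
    parameters = tool["function"]["parameters"]
    declared = set(parameters.get("properties", {}))
    return [name for name in parameters.get("required", []) if name not in declared]

def make_tool(required: list) -> dict:
    """Build a minimal tool definition with a single 'queries' property."""
    return {
        "type": "function",
        "function": {
            "name": "sql_query_func",
            "parameters": {
                "type": "object",
                "properties": {"queries": {"type": "array", "items": {"type": "string"}}},
                "required": required,
            },
        },
    }

consistent = missing_required(make_tool(["queries"]))
inconsistent = missing_required(make_tool(["question"]))
print(consistent, inconsistent)
```

An empty list means every required name is declared; anything else names the mismatch.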
Try the Tool
from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
messages = [
    ChatMessage.from_system(
        "You are a helpful and knowledgeable agent who has access to an SQL database which has a table called 'absenteeism'"
    ),
    ChatMessage.from_user("On which days of the week does the average absenteeism time exceed 4 hours?"),
]
chat_generator = OpenAIChatGenerator(model="gpt-4", streaming_callback=print_streaming_chunk)
response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools})
print(response)
{'replies': [ChatMessage(content='[{"index": 0, "id": "call_fRYwYg6iAqroHwYzPD6UxOVg", "function": {"arguments": "{\\n \\"queries\\": [\\"SELECT Day_of_the_week, AVG(Absenteeism_time_in_hours) AS Average_Absenteeism_Hours FROM absenteeism GROUP BY Day_of_the_week HAVING AVG(Absenteeism_time_in_hours) > 4\\"]\\n}", "name": "sql_query_func"}, "type": "function"}]', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-0613', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {}})]}
import json
## Parse function calling information
function_call = json.loads(response["replies"][0].text)[0]
function_name = function_call["function"]["name"]
function_args = json.loads(function_call["function"]["arguments"])
print("Function Name:", function_name)
print("Function Arguments:", function_args)
## Find the corresponding function and call it with the given arguments
available_functions = {"sql_query_func": sql_query_func}
function_to_call = available_functions[function_name]
function_response = function_to_call(**function_args)
print("Function Response:", function_response)
Function Name: sql_query_func
Function Arguments: {'queries': ['SELECT Day_of_the_week, AVG(Absenteeism_time_in_hours) AS Average_Absenteeism_Hours FROM absenteeism GROUP BY Day_of_the_week HAVING AVG(Absenteeism_time_in_hours) > 4']}
Function Response: {'reply': ' Day_of_the_week Average_Absenteeism_Hours\n0 2 9.248447\n1 3 7.980519\n2 4 7.147436\n3 5 4.424000\n4 6 5.125000'}
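The lookup-and-call step above assumes the model always names a known function and returns well-formed JSON arguments. A slightly more defensive variant might look like the sketch below. The names dispatch and echo_queries, and the sample payload, are hypothetical; the payload only imitates the shape of the tool call printed above.

```python
import json

def dispatch(tool_call: dict, available_functions: dict) -> dict:
    """Look up the requested function and call it, reporting problems as a reply."""
    name = tool_call["function"]["name"]
    if name not in available_functions:
        return {"reply": f"Unknown function: {name}"}
    try:
        arguments = json.loads(tool_call["function"]["arguments"])
    except json.JSONDecodeError as error:
        return {"reply": f"Could not parse arguments: {error}"}
    return available_functions[name](**arguments)

def echo_queries(queries):
    """Stand-in for sql_query_func that just echoes the queries it receives."""
    return {"reply": "; ".join(queries)}

sample_call = {
    "type": "function",
    "function": {"name": "echo_queries", "arguments": '{"queries": ["SELECT 1"]}'},
}
known = dispatch(sample_call, {"echo_queries": echo_queries})
unknown = dispatch(sample_call, {})
print(known)
print(unknown)
```

Returning the problem as a reply, rather than raising, keeps the shape consistent with sql_query_func, so the message can be passed back to the model in the same way.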
Build a Chat with SQL App
First, let's install Gradio, which we will use to build our mini app.
!pip install gradio
import gradio as gr
import json
from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat import OpenAIChatGenerator
chat_generator = OpenAIChatGenerator(model="gpt-4")
response = None
messages = [
ChatMessage.from_system(
"You are a helpful and knowledgeable agent who has access to an SQL database which has a table called 'absenteeism'"
)
]
def chatbot_with_fc(message, history):
    available_functions = {"sql_query_func": sql_query_func}
    messages.append(ChatMessage.from_user(message))
    response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools})
    while True:
        # if OpenAI response is a tool call
        if response and response["replies"][0].meta["finish_reason"] == "tool_calls":
            function_calls = json.loads(response["replies"][0].text)
            for function_call in function_calls:
                ## Parse function calling information
                function_name = function_call["function"]["name"]
                function_args = json.loads(function_call["function"]["arguments"])
                ## Find the corresponding function and call it with the given arguments
                function_to_call = available_functions[function_name]
                function_response = function_to_call(**function_args)
                ## Append function response to the messages list using `ChatMessage.from_function`
                messages.append(ChatMessage.from_function(content=function_response['reply'], name=function_name))
            # Ask the model again once all function responses have been appended
            response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools})
        # Regular Conversation
        else:
            messages.append(response["replies"][0])
            break
    return response["replies"][0].text
demo = gr.ChatInterface(
    fn=chatbot_with_fc,
    examples=[
        "Find the top 3 ages with the highest total absenteeism hours, excluding disciplinary failures",
        "On which days of the week does the average absenteeism time exceed 4 hours?",
        "Who lives in London?",
    ],
    title="Chat with your SQL Database",
)
demo.launch(debug=True)
Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://83eb0414c1916d8ee7.gradio.live
This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://hugging-face.cn/spaces)
ChatMessage(content='[{"id": "call_Uu8QXlIsJfYCULD4Q0bEcAtP", "function": {"arguments": "{\\n \\"queries\\": [\\n \\"SELECT Age, SUM(Absenteeism_time_in_hours) as Total_Absenteeism_Hours FROM absenteeism WHERE Disciplinary_failure = 0 GROUP BY Age ORDER BY Total_Absenteeism_Hours DESC LIMIT 3\\"\\n ]\\n}", "name": "sql_query_func"}, "type": "function"}]', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-0613', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {'completion_tokens': 68, 'prompt_tokens': 207, 'total_tokens': 275}})
Age Total_Absenteeism_Hours
0 28 651
1 33 538
2 38 482
ChatMessage(content='[{"id": "call_t8bjUHMvHHrXReB2qm3iVkNF", "function": {"arguments": "{\\n\\"queries\\": [\\"SELECT Day_of_the_week, AVG(Absenteeism_time_in_hours) as average_absenteeism_time FROM absenteeism GROUP BY Day_of_the_week HAVING average_absenteeism_time > 4\\"]\\n}", "name": "sql_query_func"}, "type": "function"}]', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-0613', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {'completion_tokens': 57, 'prompt_tokens': 320, 'total_tokens': 377}})
Day_of_the_week average_absenteeism_time
0 2 9.248447
1 3 7.980519
2 4 7.147436
3 5 4.424000
4 6 5.125000
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://83eb0414c1916d8ee7.gradio.live
