
Chat With Your SQL Database


By Tuana Celik (X, LI)

In this example, we query a SQL database!


Install Dependencies

In this demo, we use SQLite.

The first few code cells in this section download a CSV file about "absenteeism" and create a SQL table from it.

!pip install git+https://github.com/deepset-ai/haystack.git@main#egg=haystack-ai
from urllib.request import urlretrieve
from zipfile import ZipFile
import pandas as pd

url = "https://archive.ics.uci.edu/static/public/445/absenteeism+at+work.zip"

# download the file
urlretrieve(url, "Absenteeism_at_work_AAA.zip")

print("Extracting the Absenteeism at work dataset...")
# Extract the CSV file
with ZipFile("Absenteeism_at_work_AAA.zip", 'r') as zf:
    zf.extractall()

# Check the extracted CSV file name (in this case, it's "Absenteeism_at_work.csv")
csv_file_name = "Absenteeism_at_work.csv"

print("Cleaning up the Absenteeism at work dataset...")
# Data clean up
df = pd.read_csv(csv_file_name, sep=";")
df.columns = df.columns.str.replace(' ', '_')
df.columns = df.columns.str.replace('/', '_')
Extracting the Absenteeism at work dataset...
Cleaning up the Absenteeism at work dataset...
columns = df.columns.to_list()
columns = ', '.join(columns)
columns
'ID, Reason_for_absence, Month_of_absence, Day_of_the_week, Seasons, Transportation_expense, Distance_from_Residence_to_Work, Service_time, Age, Work_load_Average_day_, Hit_target, Disciplinary_failure, Education, Son, Social_drinker, Social_smoker, Pet, Weight, Height, Body_mass_index, Absenteeism_time_in_hours'
import sqlite3

connection = sqlite3.connect('absenteeism.db')
print("Opened database successfully")

connection.execute('''CREATE TABLE IF NOT EXISTS absenteeism (ID integer,
                        Reason_for_absence integer,
                        Month_of_absence integer,
                        Day_of_the_week integer,
                        Seasons integer,
                        Transportation_expense integer,
                        Distance_from_Residence_to_Work integer,
                        Service_time integer,
                        Age integer,
                        Work_load_Average_day_ integer,
                        Hit_target integer,
                        Disciplinary_failure integer,
                        Education integer,
                        Son integer,
                        Social_drinker integer,
                        Social_smoker integer,
                        Pet integer,
                        Weight integer,
                        Height integer,
                        Body_mass_index integer,
                        Absenteeism_time_in_hours integer);''')
connection.commit()
Opened database successfully
df.to_sql('absenteeism', connection, if_exists='replace', index = False)
740
connection.close()
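
Once the table is written, it can help to confirm what actually landed in the database before building anything on top of it. The sketch below is a minimal, stdlib-only illustration. It uses an in-memory SQLite database and a two-column stand-in table (both are assumptions for the example, not the real dataset), and `describe_table` is a hypothetical helper name.

```python
import sqlite3


def describe_table(connection: sqlite3.Connection, table: str) -> dict:
    """Return the column names and row count of a table (hypothetical helper)."""
    # PRAGMA table_info yields one row per column; index 1 holds the column name.
    columns = [row[1] for row in connection.execute(f"PRAGMA table_info({table})")]
    row_count = connection.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
    return {"columns": columns, "row_count": row_count}


# Stand-in data: an in-memory database instead of absenteeism.db
demo_connection = sqlite3.connect(":memory:")
demo_connection.execute("CREATE TABLE absenteeism (ID integer, Age integer)")
demo_connection.executemany("INSERT INTO absenteeism VALUES (?, ?)", [(1, 28), (2, 33)])

summary = describe_table(demo_connection, "absenteeism")
print(summary)
demo_connection.close()
```

Pointing the same helper at a fresh `sqlite3.connect('absenteeism.db')` connection should report the 21 columns listed earlier and the 740 rows returned by `to_sql`.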

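Create the SQL Query Component

Here, we create a custom component called SQLQuery so that we can use it in a Haystack pipeline like any other component (retrievers, generators, and so on). This component does one job:

  • Accepts queries (a list of SQL queries)
  • Runs each query against the database and returns the results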
from typing import List
from haystack import component

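@component
class SQLQuery:
    """Runs SQL queries against a SQLite database and returns the results as strings."""

    def __init__(self, sql_database: str):
        # check_same_thread=False allows the connection to be used from other threads
        self.connection = sqlite3.connect(sql_database, check_same_thread=False)

    @component.output_types(results=List[str], queries=List[str])
    def run(self, queries: List[str]):
        results = []
        for query in queries:
            # Each result is a DataFrame, rendered as text for downstream components
            result = pd.read_sql(query, self.connection)
            results.append(f"{result}")
        return {"results": results, "queries": queries}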
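
Because the queries this component receives will eventually be written by an LLM, one optional safeguard is to open the database in read-only mode so that a generated `DROP` or `DELETE` cannot modify the data. The sketch below is an illustration under stated assumptions, not part of the original component: it uses SQLite's URI syntax (`mode=ro`) on a temporary database file that stands in for absenteeism.db, and `try_write` is a hypothetical helper.

```python
import os
import sqlite3
import tempfile

# Build a small throwaway database file (a stand-in for absenteeism.db)
db_path = os.path.join(tempfile.mkdtemp(), "demo.db")
writer = sqlite3.connect(db_path)
writer.execute("CREATE TABLE absenteeism (ID integer, Age integer)")
writer.execute("INSERT INTO absenteeism VALUES (1, 28)")
writer.commit()
writer.close()

# Reopen the same file read-only via a SQLite URI
read_only = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)


def try_write(connection: sqlite3.Connection) -> str:
    """Attempt a write and report whether SQLite allowed it (hypothetical helper)."""
    try:
        connection.execute("DELETE FROM absenteeism")
        return "write succeeded"
    except sqlite3.OperationalError as error:
        return f"write blocked: {error}"


rows = read_only.execute("SELECT ID, Age FROM absenteeism").fetchall()
write_outcome = try_write(read_only)
print(rows)
print(write_outcome)
read_only.close()
```

Reads keep working on the read-only connection, while the write attempt is rejected by SQLite itself rather than by any string matching on the query.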

Try the SQLQuery Component

sql_query = SQLQuery('absenteeism.db')
result = sql_query.run(queries=['SELECT Age, SUM(Absenteeism_time_in_hours) as Total_Absenteeism_Hours FROM absenteeism WHERE Disciplinary_failure = 0 GROUP BY Age ORDER BY Total_Absenteeism_Hours DESC LIMIT 3;'])
print(result["results"][0])
   Age  Total_Absenteeism_Hours
0   28                      651
1   33                      538
2   38                      482

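Query a SQL Database with Natural Language

In this section, we build a simple pipeline that:

  • Accepts a natural language question
  • Converts the question into a SQL query
  • Queries our database with the SQLQuery component

Drawback: this pipeline still runs when it is asked a completely unrelated question that our database cannot answer. Notice how the SQLQuery component raises an error in those cases.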
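
To make the drawback concrete, the sketch below feeds a plain-English reply (the kind of text an LLM might return for an unanswerable question) to SQLite as if it were a query. It is a stdlib-only stand-in: it calls `sqlite3` directly rather than the SQLQuery component, and both the reply text and the `run_or_report` helper are made up for illustration.

```python
import sqlite3


def run_or_report(connection: sqlite3.Connection, query: str) -> str:
    """Run a query and return its rows as text, or the error message on failure."""
    try:
        return str(connection.execute(query).fetchall())
    except sqlite3.Error as error:
        return f"error: {error}"


connection = sqlite3.connect(":memory:")
connection.execute("CREATE TABLE absenteeism (ID integer, Age integer)")
connection.execute("INSERT INTO absenteeism VALUES (1, 28)")

valid = run_or_report(connection, "SELECT Age FROM absenteeism")
invalid = run_or_report(connection, "Sorry, the table has no birthday information.")
print(valid)
print(invalid)
connection.close()
```

The first call returns rows, while the second surfaces a SQL syntax error because the text is not a query at all.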

import os
from getpass import getpass

os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")
OpenAI API Key: ··········
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators.openai import OpenAIGenerator

prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
            The query is to be answered for the table called 'absenteeism' with the following
            Columns: {{columns}};
            Answer:""")
sql_query = SQLQuery('absenteeism.db')
llm = OpenAIGenerator(model="gpt-4")

sql_pipeline = Pipeline()
sql_pipeline.add_component("prompt", prompt)
sql_pipeline.add_component("llm", llm)
sql_pipeline.add_component("sql_querier", sql_query)

sql_pipeline.connect("prompt", "llm")
sql_pipeline.connect("llm.replies", "sql_querier.queries")

# If you want to draw the pipeline, uncomment below 👇
# sql_pipeline.show()
result = sql_pipeline.run({"prompt": {"question": "On which days of the week does the average absenteeism time exceed 4 hours?",
                            "columns": columns}})

print(result["sql_querier"]["results"][0])
   Day_of_the_week
0                2
1                3
2                4
3                5
4                6

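Skip Unrelated Questions: Add a Condition

Now, let's create another pipeline that avoids querying the database when the question is unrelated to the information in it. To do this, we make a few changes:

  • We modify the prompt so that it returns no_answer when the question cannot be answered given the database and its columns
  • We add a conditional router that routes the query to the SQLQuery component only when the question was evaluated as answerable
  • We add a fallback_prompt and a fallback_llm that return a statement explaining that the question cannot be answered and why. This branch of the pipeline runs only when the question cannot be answered.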
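
The routing rule described in the list above can be sketched in plain Python before wiring it into Haystack. This is only an illustration of the decision, not how ConditionalRouter is implemented. `route_reply` is a hypothetical function and the example replies are made up.

```python
from typing import Dict, List


def route_reply(replies: List[str], question: str) -> Dict[str, object]:
    """Mimic the two routes: forward SQL when answerable, else forward the question."""
    if "no_answer" in replies[0]:
        # The LLM signalled that the table cannot answer the question
        return {"go_to_fallback": question}
    # Otherwise treat the replies as SQL queries for the SQLQuery component
    return {"sql": replies}


answerable = route_reply(["SELECT COUNT(*) FROM absenteeism"], "How many rows are there?")
unanswerable = route_reply(["no_answer"], "When is my birthday?")
print(answerable)
print(unanswerable)
```

Exactly one of the two output names is produced per question, which is why only one branch of the pipeline runs.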
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators.openai import OpenAIGenerator
from haystack.components.routers import ConditionalRouter

prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
            If the question cannot be answered given the provided table and columns, return 'no_answer'
            The query is to be answered for the table called 'absenteeism' with the following
            Columns: {{columns}};
            Answer:""")

llm = OpenAIGenerator(model="gpt-4")
sql_query = SQLQuery('absenteeism.db')

routes = [
     {
        "condition": "{{'no_answer' not in replies[0]}}",
        "output": "{{replies}}",
        "output_name": "sql",
        "output_type": List[str],
    },
    {
        "condition": "{{'no_answer' in replies[0]}}",
        "output": "{{question}}",
        "output_name": "go_to_fallback",
        "output_type": str,
    },
]

router = ConditionalRouter(routes)

fallback_prompt = PromptBuilder(template="""User entered a query that cannot be answered with the given table.
                                            The query was: {{question}} and the table had columns: {{columns}}.
                                            Let the user know why the question cannot be answered""")
fallback_llm = OpenAIGenerator(model="gpt-4")

conditional_sql_pipeline = Pipeline()
conditional_sql_pipeline.add_component("prompt", prompt)
conditional_sql_pipeline.add_component("llm", llm)
conditional_sql_pipeline.add_component("router", router)
conditional_sql_pipeline.add_component("fallback_prompt", fallback_prompt)
conditional_sql_pipeline.add_component("fallback_llm", fallback_llm)
conditional_sql_pipeline.add_component("sql_querier", sql_query)

conditional_sql_pipeline.connect("prompt", "llm")
conditional_sql_pipeline.connect("llm.replies", "router.replies")
conditional_sql_pipeline.connect("router.sql", "sql_querier.queries")
conditional_sql_pipeline.connect("router.go_to_fallback", "fallback_prompt.question")
conditional_sql_pipeline.connect("fallback_prompt", "fallback_llm")

#if you want to draw the pipeline, uncomment below 👇

#conditional_sql_pipeline.show()
question = "When is my birthday?"
result = conditional_sql_pipeline.run({"prompt": {"question": question,
                                                  "columns": columns},
                                       "router": {"question": question},
                                       "fallback_prompt": {"columns": columns}})
if 'sql_querier' in result:
  print(result['sql_querier']['results'][0])
elif 'fallback_llm' in result:
  print(result['fallback_llm']['replies'][0])
The query cannot be answered as the provided table does not contain information regarding the user's personal data such as birthdays. The table primarily focuses on absence-related data for presumably work or similar situations. Please provide the relevant data to get the accurate answer.

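Function Calling to Query a SQL Database

Now let's try something more interesting. Instead of using a component, we provide SQL querying as a function. Since we have already built it, we can simply wrap the SQLQuery component in a function 👇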

sql_query = SQLQuery('absenteeism.db')

def sql_query_func(queries: List[str]):
    try:
        result = sql_query.run(queries)
        return {"reply": result["results"][0]}

    except Exception as e:
        reply = f"""There was an error running the SQL Query = {queries}
              The error is {e},
              You should probably try again.
              """
        return {"reply": reply}

Define the Tool

Now, let's provide this function as a tool. Below, we use OpenAI for demonstration purposes, so we follow their function definition schema 👇

tools = [
    {
        "type": "function",
        "function": {
            "name": "sql_query_func",
            "description": f"This is a tool for querying a SQL table called 'absenteeism' with the following Columns: {columns}",
            "parameters": {
                "type": "object",
                "properties": {
                    "queries": {
                        "type": "array",
                        "description": "The SQL queries to run against the table. Infer these from the user's message.",
                        "items": {
                            "type": "string",
                        }
                    }
                },
                "required": ["queries"],
            },
        },
    }
]
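
Tool definitions are plain dictionaries, so a mismatch between `required` and `properties` is easy to introduce and will not raise an error locally. The sketch below is a small consistency check. `undeclared_required` is a hypothetical helper, and the two tool dictionaries are trimmed examples written for this illustration.

```python
from typing import Dict, List


def undeclared_required(tool: Dict) -> List[str]:
    """Return names listed under `required` that are missing from `properties`."""
    parameters = tool["function"]["parameters"]
    declared = set(parameters.get("properties", {}))
    return [name for name in parameters.get("required", []) if name not in declared]


# A trimmed-down, consistent tool definition used only for this example
example_tool = {
    "type": "function",
    "function": {
        "name": "sql_query_func",
        "parameters": {
            "type": "object",
            "properties": {"queries": {"type": "array", "items": {"type": "string"}}},
            "required": ["queries"],
        },
    },
}

# A deliberately inconsistent copy: `required` names a property that is not declared
broken_tool = {
    "type": "function",
    "function": {
        "name": "sql_query_func",
        "parameters": {
            "type": "object",
            "properties": {"queries": {"type": "array", "items": {"type": "string"}}},
            "required": ["question"],
        },
    },
}

consistent_result = undeclared_required(example_tool)
broken_result = undeclared_required(broken_tool)
print(consistent_result)
print(broken_result)
```

An empty list means every required parameter is declared. Any names that are returned point at entries worth fixing before the definition is sent to the model.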

Try the Tool

from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk

messages = [
    ChatMessage.from_system(
        "You are a helpful and knowledgeable agent who has access to an SQL database which has a table called 'absenteeism'"
    ),
    ChatMessage.from_user("On which days of the week does the average absenteeism time exceed 4 hours?"),
]

chat_generator = OpenAIChatGenerator(model="gpt-4", streaming_callback=print_streaming_chunk)
response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools})
print(response)
{'replies': [ChatMessage(content='[{"index": 0, "id": "call_fRYwYg6iAqroHwYzPD6UxOVg", "function": {"arguments": "{\\n  \\"queries\\": [\\"SELECT Day_of_the_week, AVG(Absenteeism_time_in_hours) AS Average_Absenteeism_Hours FROM absenteeism GROUP BY Day_of_the_week HAVING AVG(Absenteeism_time_in_hours) > 4\\"]\\n}", "name": "sql_query_func"}, "type": "function"}]', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-0613', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {}})]}
import json

## Parse function calling information
function_call = json.loads(response["replies"][0].text)[0]
function_name = function_call["function"]["name"]
function_args = json.loads(function_call["function"]["arguments"])
print("Function Name:", function_name)
print("Function Arguments:", function_args)

## Find the corresponding function and call it with the given arguments
available_functions = {"sql_query_func": sql_query_func}
function_to_call = available_functions[function_name]
function_response = function_to_call(**function_args)
print("Function Response:", function_response)
Function Name: sql_query_func
Function Arguments: {'queries': ['SELECT Day_of_the_week, AVG(Absenteeism_time_in_hours) AS Average_Absenteeism_Hours FROM absenteeism GROUP BY Day_of_the_week HAVING AVG(Absenteeism_time_in_hours) > 4']}
Function Response: {'reply': '   Day_of_the_week  Average_Absenteeism_Hours\n0                2                   9.248447\n1                3                   7.980519\n2                4                   7.147436\n3                5                   4.424000\n4                6                   5.125000'}
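
The lookup above assumes the model always names a function that exists and always sends valid JSON arguments. A slightly more defensive dispatch might look like the sketch below. It is an assumption-based illustration rather than part of the original notebook: `dispatch_tool_call` and `fake_sql_query_func` are hypothetical names, and the two tool-call payloads are hand-written.

```python
import json
from typing import Callable, Dict


def dispatch_tool_call(function_call: Dict, available_functions: Dict[str, Callable]) -> Dict:
    """Call the requested function, or return an error reply the model can read."""
    function_name = function_call["function"]["name"]
    if function_name not in available_functions:
        return {"reply": f"Unknown function: {function_name}"}
    try:
        function_args = json.loads(function_call["function"]["arguments"])
    except json.JSONDecodeError as error:
        return {"reply": f"Could not parse arguments: {error}"}
    return available_functions[function_name](**function_args)


def fake_sql_query_func(queries):
    # Stand-in for sql_query_func so the example runs without a database
    return {"reply": f"ran {len(queries)} query(ies)"}


functions = {"sql_query_func": fake_sql_query_func}
good_call = {"function": {"name": "sql_query_func", "arguments": '{"queries": ["SELECT 1"]}'}}
bad_call = {"function": {"name": "drop_everything", "arguments": "{}"}}

good_result = dispatch_tool_call(good_call, functions)
bad_result = dispatch_tool_call(bad_call, functions)
print(good_result)
print(bad_result)
```

Returning an error message in the same `{"reply": ...}` shape keeps the calling code uniform and gives the model a chance to correct itself on the next turn.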

Build a Chat with SQL App

First, let's install Gradio, which we'll use to build our mini app.

!pip install gradio
import gradio as gr
import json

from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat import OpenAIChatGenerator

chat_generator = OpenAIChatGenerator(model="gpt-4")
response = None
messages = [
    ChatMessage.from_system(
        "You are a helpful and knowledgeable agent who has access to an SQL database which has a table called 'absenteeism'"
    )
]


def chatbot_with_fc(message, history):
    available_functions = {"sql_query_func": sql_query_func}
    messages.append(ChatMessage.from_user(message))
    response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools})

    while True:
        # if OpenAI response is a tool call
        if response and response["replies"][0].meta["finish_reason"] == "tool_calls":
            function_calls = json.loads(response["replies"][0].text)
            for function_call in function_calls:
                ## Parse function calling information
                function_name = function_call["function"]["name"]
                function_args = json.loads(function_call["function"]["arguments"])

                ## Find the corresponding function and call it with the given arguments
                function_to_call = available_functions[function_name]
                function_response = function_to_call(**function_args)
                ## Append function response to the messages list using `ChatMessage.from_function`
                messages.append(ChatMessage.from_function(content=function_response['reply'], name=function_name))
                response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools})

        # Regular Conversation
        else:
            messages.append(response["replies"][0])
            break
    return response["replies"][0].text
demo = gr.ChatInterface(
    fn=chatbot_with_fc,
    examples=[
        "Find the top 3 ages with the highest total absenteeism hours, excluding disciplinary failures",
        "On which days of the week does the average absenteeism time exceed 4 hours?",
        "Who lives in London?",
    ],
    title="Chat with your SQL Database",
)
demo.launch(debug=True)
Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://83eb0414c1916d8ee7.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://hugging-face.cn/spaces)
ChatMessage(content='[{"id": "call_Uu8QXlIsJfYCULD4Q0bEcAtP", "function": {"arguments": "{\\n  \\"queries\\": [\\n    \\"SELECT Age, SUM(Absenteeism_time_in_hours) as Total_Absenteeism_Hours FROM absenteeism WHERE Disciplinary_failure = 0 GROUP BY Age ORDER BY Total_Absenteeism_Hours DESC LIMIT 3\\"\\n  ]\\n}", "name": "sql_query_func"}, "type": "function"}]', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-0613', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {'completion_tokens': 68, 'prompt_tokens': 207, 'total_tokens': 275}})
   Age  Total_Absenteeism_Hours
0   28                      651
1   33                      538
2   38                      482
ChatMessage(content='[{"id": "call_t8bjUHMvHHrXReB2qm3iVkNF", "function": {"arguments": "{\\n\\"queries\\": [\\"SELECT Day_of_the_week, AVG(Absenteeism_time_in_hours) as average_absenteeism_time FROM absenteeism GROUP BY Day_of_the_week HAVING average_absenteeism_time > 4\\"]\\n}", "name": "sql_query_func"}, "type": "function"}]', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-0613', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {'completion_tokens': 57, 'prompt_tokens': 320, 'total_tokens': 377}})
   Day_of_the_week  average_absenteeism_time
0                2                  9.248447
1                3                  7.980519
2                4                  7.147436
3                5                  4.424000
4                6                  5.125000
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://83eb0414c1916d8ee7.gradio.live