|
|
@@ -0,0 +1,359 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 1,
|
|
|
+ "id": "2671054f",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "C:\\Users\\34450\\AppData\\Local\\Temp\\ipykernel_16268\\3464868009.py:2: DeprecationWarning: `langchain-community` is being sunset and is no longer actively maintained. See https://github.com/langchain-ai/langchain-community/issues/674 for details and migration guidance toward standalone integration packages.\n",
|
|
|
+ " from langchain_community.utilities import SQLDatabase\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "数据库连接成功\n",
|
|
|
+ " 可用表:['employees', 'orders', 'products']\n",
|
|
|
+ "\n",
|
|
|
+ "SQL 工具包已加载(4 个工具):\n",
|
|
|
+ " - sql_db_query\n",
|
|
|
+ " - sql_db_schema\n",
|
|
|
+ " - sql_db_list_tables\n",
|
|
|
+ " - sql_db_query_checker\n",
|
|
|
+ "\n",
|
|
|
+ "NL2SQL Agent 创建完成,可以开始提问了!\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "import os\n",
|
|
|
+ "from langchain_community.utilities import SQLDatabase\n",
|
|
|
+ "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
|
|
|
+ "from langchain_openai import ChatOpenAI\n",
|
|
|
+ "from langchain.agents import create_agent # langchain 1.3.1 的新 API\n",
|
|
|
+ "from dotenv import load_dotenv \n",
|
|
|
+ "\n",
|
|
|
+ "load_dotenv() # 默认加载当前目录或父目录中的 .env 文件\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# 第一步:连接数据库\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# SQLDatabase.from_uri 接受标准的数据库连接 URI\n",
|
|
|
+ "# LangChain 内部会用 SQLAlchemy 管理连接池\n",
|
|
|
+ "db_uri = (\n",
|
|
|
+ " f\"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}\"\n",
|
|
|
+ " f\"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}\"\n",
|
|
|
+ ")\n",
|
|
|
+ "db = SQLDatabase.from_uri(db_uri)\n",
|
|
|
+ "\n",
|
|
|
+ "# 验证连接:打印可用的表名\n",
|
|
|
+ "print(f\"数据库连接成功\")\n",
|
|
|
+ "print(f\" 可用表:{db.get_usable_table_names()}\")\n",
|
|
|
+ "\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# 第二步:初始化大模型\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "llm = ChatOpenAI(\n",
|
|
|
+ " model=os.getenv(\"DEEPSEEK_MODEL\"),\n",
|
|
|
+ " api_key=os.getenv(\"DEEPSEEK_API_KEY\"),\n",
|
|
|
+ " base_url=os.getenv(\"DEEPSEEK_BASE_URL\"),\n",
|
|
|
+ " temperature=0, # SQL 生成需要确定性输出\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# 第三步:创建 SQL 工具包\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# SQLDatabaseToolkit 会自动注册 4 个工具:\n",
|
|
|
+ "# 1. sql_db_list_tables — 列出数据库中所有表\n",
|
|
|
+ "# 2. sql_db_schema — 获取指定表的 DDL 结构\n",
|
|
|
+ "# 3. sql_db_query_checker — 检查 SQL 语法是否正确\n",
|
|
|
+ "# 4. sql_db_query — 执行 SQL 并返回结果\n",
|
|
|
+ "toolkit = SQLDatabaseToolkit(db=db, llm=llm)\n",
|
|
|
+ "tools = toolkit.get_tools()\n",
|
|
|
+ "\n",
|
|
|
+ "print(f\"\\nSQL 工具包已加载({len(tools)} 个工具):\")\n",
|
|
|
+ "for t in tools:\n",
|
|
|
+ " print(f\" - {t.name}\")\n",
|
|
|
+ "\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# 第四步:定义 System Prompt(Agent 的\"岗位说明书\")\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "SQL_AGENT_PROMPT = \"\"\"你是一名专业的 SQL 数据分析师。\n",
|
|
|
+ "\n",
|
|
|
+ "## 工作流程\n",
|
|
|
+ "1. 先用 sql_db_list_tables 查看数据库中有哪些表\n",
|
|
|
+ "2. 用 sql_db_schema 获取相关表的字段结构和类型\n",
|
|
|
+ "3. 生成 SQL 之前,用 sql_db_query_checker 检查语法\n",
|
|
|
+ "4. 确认无误后,用 sql_db_query 执行查询\n",
|
|
|
+ "5. 用中文总结查询结果,给出简洁的业务洞察\n",
|
|
|
+ "\n",
|
|
|
+ "## 约束\n",
|
|
|
+ "- 只使用数据库中实际存在的表和字段,不要凭空编造\n",
|
|
|
+ "- 单次查询结果限制在 50 条以内\n",
|
|
|
+ "- 如果查询出错,分析错误原因后重新生成 SQL\n",
|
|
|
+ "- 回答要简洁专业,不要啰嗦\n",
|
|
|
+ "\"\"\"\n",
|
|
|
+ "\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# 第五步:组装 Agent(langchain 1.3.1 一行搞定)\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# create_agent 返回一个可直接调用的 Runnable,无需再套 AgentExecutor\n",
|
|
|
+ "# 它内部自动处理工具调用、循环迭代、错误重试等逻辑\n",
|
|
|
+ "sql_agent = create_agent(\n",
|
|
|
+ " model=llm,\n",
|
|
|
+ " tools=tools,\n",
|
|
|
+ " system_prompt=SQL_AGENT_PROMPT,\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "print(\"\\nNL2SQL Agent 创建完成,可以开始提问了!\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 5,
|
|
|
+ "id": "bd1fef44",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "回答:公司共有 **5 名员工**。\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# 简单问题:Agent 只需一条 COUNT SQL 就能搞定\n",
|
|
|
+ "# create_agent 的 invoke 接口使用 messages 格式\n",
|
|
|
+ "response = sql_agent.invoke({\n",
|
|
|
+ " \"messages\": [{\"role\": \"user\", \"content\": \"公司一共有多少名员工?\"}]\n",
|
|
|
+ "})\n",
|
|
|
+ "\n",
|
|
|
+ "# 最终回答在最后一条消息里\n",
|
|
|
+ "final_msg = response[\"messages\"][-1]\n",
|
|
|
+ "print(f\"回答:{final_msg.content}\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "id": "a7025030",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "回答:## 查询结果\n",
|
|
|
+ "\n",
|
|
|
+ "技术部共有 **2 名员工**,按薪资从高到低排序如下:\n",
|
|
|
+ "\n",
|
|
|
+ "| 姓名 | 薪资 |\n",
|
|
|
+ "|------|------|\n",
|
|
|
+ "| 张三 | 20,000.00 |\n",
|
|
|
+ "| 王五 | 16,000.00 |\n",
|
|
|
+ "\n",
|
|
|
+ "**业务洞察:**\n",
|
|
|
+ "- 技术部员工平均薪资为 **18,000 元**。\n",
|
|
|
+ "- 张三薪资最高(20,000元),比第二名王五高出 **25%**。\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# 中等难度:需要 JOIN + WHERE + GROUP BY\n",
|
|
|
+ "response = sql_agent.invoke({\n",
|
|
|
+ " \"messages\": [{\"role\": \"user\", \"content\": \"列出技术部所有员工的姓名和薪资,按薪资从高到低排序\"}]\n",
|
|
|
+ "})\n",
|
|
|
+ "final_msg = response[\"messages\"][-1]\n",
|
|
|
+ "print(f\"回答:{final_msg.content}\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 7,
|
|
|
+ "id": "cb41a941",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "回答:查询结果如下:\n",
|
|
|
+ "\n",
|
|
|
+ "---\n",
|
|
|
+ "\n",
|
|
|
+ "### 📊 销售部订单统计\n",
|
|
|
+ "\n",
|
|
|
+ "| 指标 | 数据 |\n",
|
|
|
+ "|------|------|\n",
|
|
|
+ "| **订单总数** | **3 笔** |\n",
|
|
|
+ "| **订单总金额** | **¥14,979.00** |\n",
|
|
|
+ "\n",
|
|
|
+ "### 👥 员工明细\n",
|
|
|
+ "\n",
|
|
|
+ "| 员工 | 订单数 | 金额 |\n",
|
|
|
+ "|-----|:-----:|:----:|\n",
|
|
|
+ "| **李四** | 2 笔 | ¥11,985.00 |\n",
|
|
|
+ "| **钱七** | 1 笔 | ¥2,994.00 |\n",
|
|
|
+ "\n",
|
|
|
+ "**业务洞察:**\n",
|
|
|
+ "- 销售部共 **2 名员工**,累计下单 **3 笔**,总金额 **14,979 元**。\n",
|
|
|
+ "- 李四贡献了总金额的 **80%**(11,985 元),是销售部的核心业务人员。\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# 高难度:需要 JOIN 三张表 + 聚合计算\n",
|
|
|
+ "response = sql_agent.invoke({\n",
|
|
|
+ " \"messages\": [{\"role\": \"user\", \"content\": \"销售部的员工总共下了多少订单?订单总金额是多少?\"}]\n",
|
|
|
+ "})\n",
|
|
|
+ "final_msg = response[\"messages\"][-1]\n",
|
|
|
+ "print(f\"回答:{final_msg.content}\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 4,
|
|
|
+ "id": "5ce44c51",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\n",
|
|
|
+ "最终回答:每个部门各有多少人?\n",
|
|
|
+ "\n",
|
|
|
+ "步骤 2 [AIMessage]:\n",
|
|
|
+ " 调用工具:sql_db_list_tables\n",
|
|
|
+ " 参数:{}\n",
|
|
|
+ " 工具返回:employees, orders, products\n",
|
|
|
+ "\n",
|
|
|
+ "步骤 4 [AIMessage]:\n",
|
|
|
+ " 调用工具:sql_db_schema\n",
|
|
|
+ " 参数:{'table_names': 'employees'}\n",
|
|
|
+ " 工具返回:\n",
|
|
|
+ "CREATE TABLE employees (\n",
|
|
|
+ "\tid INTEGER NOT NULL AUTO_INCREMENT, \n",
|
|
|
+ "\tname VARCHAR(50) NOT NULL, \n",
|
|
|
+ "\tdepartment VARCHAR(50) NOT NULL, \n",
|
|
|
+ "\tsalary DECIMAL(10, 2) NOT NULL, \n",
|
|
|
+ "\thire_date DATE, \n",
|
|
|
+ "\tPRIMARY KEY (id)\n",
|
|
|
+ ")E\n",
|
|
|
+ "\n",
|
|
|
+ "步骤 6 [AIMessage]:\n",
|
|
|
+ " 调用工具:sql_db_query_checker\n",
|
|
|
+ " 参数:{'query': 'SELECT department, COUNT(*) AS 人数 FROM employees GROUP BY department ORDER BY 人数 DESC'}\n",
|
|
|
+ " 工具返回:SELECT department, COUNT(*) AS 人数 FROM employees GROUP BY department ORDER BY 人数 DESC\n",
|
|
|
+ "\n",
|
|
|
+ "步骤 8 [AIMessage]:\n",
|
|
|
+ " 调用工具:sql_db_query\n",
|
|
|
+ " 参数:{'query': 'SELECT department, COUNT(*) AS 人数 FROM employees GROUP BY department ORDER BY 人数 DESC'}\n",
|
|
|
+ " 工具返回:[('技术部', 2), ('销售部', 2), ('人力资源', 1)]\n",
|
|
|
+ "\n",
|
|
|
+ "最终回答:查询完成,以下是 **各部门人数统计**:\n",
|
|
|
+ "\n",
|
|
|
+ "| 部门 | 人数 |\n",
|
|
|
+ "|------|:----:|\n",
|
|
|
+ "| 🛠️ 技术部 | 2 人 |\n",
|
|
|
+ "| 📊 销售部 | 2 人 |\n",
|
|
|
+ "| 👥 人力资源 | 1 人 |\n",
|
|
|
+ "\n",
|
|
|
+ "**业务洞察:** 目前公司共有 5 名员工,技术部和销售部各占 2 人,人力资源部 1 人。技术部和销售部人员配置相对均衡。\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# 遍历所有消息,还原 Agent 的完整推理链\n",
|
|
|
+ "response = sql_agent.invoke({\n",
|
|
|
+ " \"messages\": [{\"role\": \"user\", \"content\": \"每个部门各有多少人?\"}]\n",
|
|
|
+ "})\n",
|
|
|
+ "\n",
|
|
|
+ "for i, msg in enumerate(response[\"messages\"]):\n",
|
|
|
+ " msg_type = msg.__class__.__name__\n",
|
|
|
+ "\n",
|
|
|
+ " # AIMessage 中如果有 tool_calls,说明 Agent 调用了工具\n",
|
|
|
+ " if hasattr(msg, \"tool_calls\") and msg.tool_calls:\n",
|
|
|
+ " print(f\"\\n步骤 {i+1} [{msg_type}]:\")\n",
|
|
|
+ " for tc in msg.tool_calls:\n",
|
|
|
+ " print(f\" 调用工具:{tc['name']}\")\n",
|
|
|
+ " print(f\" 参数:{tc.get('args', {})}\")\n",
|
|
|
+ "\n",
|
|
|
+ " # ToolMessage 是工具返回的结果\n",
|
|
|
+ " elif msg_type == \"ToolMessage\":\n",
|
|
|
+ " print(f\" 工具返回:{msg.content[:200]}\")\n",
|
|
|
+ "\n",
|
|
|
+ " # AIMessage 的最终文本回答\n",
|
|
|
+ " elif msg.content:\n",
|
|
|
+ " print(f\"\\n最终回答:{msg.content}\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "019a2781",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import pandas as pd\n",
|
|
|
+ "import matplotlib.pyplot as plt\n",
|
|
|
+ "import seaborn as sns\n",
|
|
|
+ "import numpy as np\n",
|
|
|
+ "\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# 配置 matplotlib 中文显示\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# Windows 用 SimHei(黑体),macOS 用 PingFang SC\n",
|
|
|
+ "# 如果还是乱码,试试安装 fonts-noto-cjk 并清除缓存\n",
|
|
|
+ "plt.rcParams[\"font.sans-serif\"] = [\"SimHei\", \"PingFang SC\", \"DejaVu Sans\"]\n",
|
|
|
+ "plt.rcParams[\"axes.unicode_minus\"] = False # 解决负号显示为方块的问题\n",
|
|
|
+ "\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# 从数据库加载数据到 Pandas DataFrame\n",
|
|
|
+ "# ============================================================\n",
|
|
|
+ "# 用 SQLAlchemy engine 复用连接,避免重复创建连接池\n",
|
|
|
+ "from sqlalchemy import create_engine\n",
|
|
|
+ "\n",
|
|
|
+ "engine = create_engine(db_uri)\n",
|
|
|
+ "\n",
|
|
|
+ "employees_df = pd.read_sql(\"SELECT * FROM employees\", engine)\n",
|
|
|
+ "products_df = pd.read_sql(\"SELECT * FROM products\", engine)\n",
|
|
|
+ "orders_df = pd.read_sql(\"SELECT * FROM orders\", engine)\n",
|
|
|
+ "\n",
|
|
|
+ "print(\"✅ 数据加载完成\")\n",
|
|
|
+ "print(f\" 员工表:{len(employees_df)} 行 × {len(employees_df.columns)} 列\")\n",
|
|
|
+ "print(f\" 产品表:{len(products_df)} 行 × {len(products_df.columns)} 列\")\n",
|
|
|
+ "print(f\" 订单表:{len(orders_df)} 行 × {len(orders_df.columns)} 列\")\n",
|
|
|
+ "print(f\"\\n📋 员工表示例:\")\n",
|
|
|
+ "print(employees_df.head(3).to_string(index=False))"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "analytics_demo",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.11.15"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|