{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "2671054f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\34450\\AppData\\Local\\Temp\\ipykernel_16268\\3464868009.py:2: DeprecationWarning: `langchain-community` is being sunset and is no longer actively maintained. See https://github.com/langchain-ai/langchain-community/issues/674 for details and migration guidance toward standalone integration packages.\n", " from langchain_community.utilities import SQLDatabase\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "数据库连接成功\n", " 可用表:['employees', 'orders', 'products']\n", "\n", "SQL 工具包已加载(4 个工具):\n", " - sql_db_query\n", " - sql_db_schema\n", " - sql_db_list_tables\n", " - sql_db_query_checker\n", "\n", "NL2SQL Agent 创建完成,可以开始提问了!\n" ] } ], "source": [ "import os\n", "from langchain_community.utilities import SQLDatabase\n", "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n", "from langchain_openai import ChatOpenAI\n", "from langchain.agents import create_agent # langchain 1.3.1 的新 API\n", "from dotenv import load_dotenv \n", "\n", "load_dotenv() # 默认加载当前目录或父目录中的 .env 文件\n", "# ============================================================\n", "# 第一步:连接数据库\n", "# ============================================================\n", "# SQLDatabase.from_uri 接受标准的数据库连接 URI\n", "# LangChain 内部会用 SQLAlchemy 管理连接池\n", "db_uri = (\n", " f\"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}\"\n", " f\"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}\"\n", ")\n", "db = SQLDatabase.from_uri(db_uri)\n", "\n", "# 验证连接:打印可用的表名\n", "print(f\"数据库连接成功\")\n", "print(f\" 可用表:{db.get_usable_table_names()}\")\n", "\n", "# ============================================================\n", "# 第二步:初始化大模型\n", "# ============================================================\n", "llm = ChatOpenAI(\n", " model=os.getenv(\"DEEPSEEK_MODEL\"),\n", " api_key=os.getenv(\"DEEPSEEK_API_KEY\"),\n", " base_url=os.getenv(\"DEEPSEEK_BASE_URL\"),\n", " temperature=0, # SQL 生成需要确定性输出\n", ")\n", "\n", "# ============================================================\n", "# 第三步:创建 SQL 工具包\n", "# ============================================================\n", "# SQLDatabaseToolkit 会自动注册 4 个工具:\n", "# 1. sql_db_list_tables — 列出数据库中所有表\n", "# 2. sql_db_schema — 获取指定表的 DDL 结构\n", "# 3. sql_db_query_checker — 检查 SQL 语法是否正确\n", "# 4. sql_db_query — 执行 SQL 并返回结果\n", "toolkit = SQLDatabaseToolkit(db=db, llm=llm)\n", "tools = toolkit.get_tools()\n", "\n", "print(f\"\\nSQL 工具包已加载({len(tools)} 个工具):\")\n", "for t in tools:\n", " print(f\" - {t.name}\")\n", "\n", "# ============================================================\n", "# 第四步:定义 System Prompt(Agent 的\"岗位说明书\")\n", "# ============================================================\n", "SQL_AGENT_PROMPT = \"\"\"你是一名专业的 SQL 数据分析师。\n", "\n", "## 工作流程\n", "1. 先用 sql_db_list_tables 查看数据库中有哪些表\n", "2. 用 sql_db_schema 获取相关表的字段结构和类型\n", "3. 生成 SQL 之前,用 sql_db_query_checker 检查语法\n", "4. 确认无误后,用 sql_db_query 执行查询\n", "5. 用中文总结查询结果,给出简洁的业务洞察\n", "\n", "## 约束\n", "- 只使用数据库中实际存在的表和字段,不要凭空编造\n", "- 单次查询结果限制在 50 条以内\n", "- 如果查询出错,分析错误原因后重新生成 SQL\n", "- 回答要简洁专业,不要啰嗦\n", "\"\"\"\n", "\n", "# ============================================================\n", "# 第五步:组装 Agent(langchain 1.3.1 一行搞定)\n", "# ============================================================\n", "# create_agent 返回一个可直接调用的 Runnable,无需再套 AgentExecutor\n", "# 它内部自动处理工具调用、循环迭代、错误重试等逻辑\n", "sql_agent = create_agent(\n", " model=llm,\n", " tools=tools,\n", " system_prompt=SQL_AGENT_PROMPT,\n", ")\n", "\n", "print(\"\\nNL2SQL Agent 创建完成,可以开始提问了!\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "bd1fef44", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "回答:公司共有 **5 名员工**。\n" ] } ], "source": [ "# 简单问题:Agent 只需一条 COUNT SQL 就能搞定\n", "# create_agent 的 invoke 接口使用 messages 格式\n", "response = sql_agent.invoke({\n", " \"messages\": [{\"role\": \"user\", \"content\": \"公司一共有多少名员工?\"}]\n", "})\n", "\n", "# 最终回答在最后一条消息里\n", "final_msg = response[\"messages\"][-1]\n", "print(f\"回答:{final_msg.content}\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "a7025030", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "回答:## 查询结果\n", "\n", "技术部共有 **2 名员工**,按薪资从高到低排序如下:\n", "\n", "| 姓名 | 薪资 |\n", "|------|------|\n", "| 张三 | 20,000.00 |\n", "| 王五 | 16,000.00 |\n", "\n", "**业务洞察:**\n", "- 技术部员工平均薪资为 **18,000 元**。\n", "- 张三薪资最高(20,000元),比第二名王五高出 **25%**。\n" ] } ], "source": [ "# 中等难度:需要 JOIN + WHERE + GROUP BY\n", "response = sql_agent.invoke({\n", " \"messages\": [{\"role\": \"user\", \"content\": \"列出技术部所有员工的姓名和薪资,按薪资从高到低排序\"}]\n", "})\n", "final_msg = response[\"messages\"][-1]\n", "print(f\"回答:{final_msg.content}\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "cb41a941", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "回答:查询结果如下:\n", "\n", "---\n", "\n", "### 📊 销售部订单统计\n", "\n", "| 指标 | 数据 |\n", "|------|------|\n", "| **订单总数** | **3 笔** |\n", "| **订单总金额** | **¥14,979.00** |\n", "\n", "### 👥 员工明细\n", "\n", "| 员工 | 订单数 | 金额 |\n", "|-----|:-----:|:----:|\n", "| **李四** | 2 笔 | ¥11,985.00 |\n", "| **钱七** | 1 笔 | ¥2,994.00 |\n", "\n", "**业务洞察:**\n", "- 销售部共 **2 名员工**,累计下单 **3 笔**,总金额 **14,979 元**。\n", "- 李四贡献了总金额的 **80%**(11,985 元),是销售部的核心业务人员。\n" ] } ], "source": [ "# 高难度:需要 JOIN 三张表 + 聚合计算\n", "response = sql_agent.invoke({\n", " \"messages\": [{\"role\": \"user\", \"content\": \"销售部的员工总共下了多少订单?订单总金额是多少?\"}]\n", "})\n", "final_msg = response[\"messages\"][-1]\n", "print(f\"回答:{final_msg.content}\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "5ce44c51", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "最终回答:每个部门各有多少人?\n", "\n", "步骤 2 [AIMessage]:\n", " 调用工具:sql_db_list_tables\n", " 参数:{}\n", " 工具返回:employees, orders, products\n", "\n", "步骤 4 [AIMessage]:\n", " 调用工具:sql_db_schema\n", " 参数:{'table_names': 'employees'}\n", " 工具返回:\n", "CREATE TABLE employees (\n", "\tid INTEGER NOT NULL AUTO_INCREMENT, \n", "\tname VARCHAR(50) NOT NULL, \n", "\tdepartment VARCHAR(50) NOT NULL, \n", "\tsalary DECIMAL(10, 2) NOT NULL, \n", "\thire_date DATE, \n", "\tPRIMARY KEY (id)\n", ")E\n", "\n", "步骤 6 [AIMessage]:\n", " 调用工具:sql_db_query_checker\n", " 参数:{'query': 'SELECT department, COUNT(*) AS 人数 FROM employees GROUP BY department ORDER BY 人数 DESC'}\n", " 工具返回:SELECT department, COUNT(*) AS 人数 FROM employees GROUP BY department ORDER BY 人数 DESC\n", "\n", "步骤 8 [AIMessage]:\n", " 调用工具:sql_db_query\n", " 参数:{'query': 'SELECT department, COUNT(*) AS 人数 FROM employees GROUP BY department ORDER BY 人数 DESC'}\n", " 工具返回:[('技术部', 2), ('销售部', 2), ('人力资源', 1)]\n", "\n", "最终回答:查询完成,以下是 **各部门人数统计**:\n", "\n", "| 部门 | 人数 |\n", "|------|:----:|\n", "| 🛠️ 技术部 | 2 人 |\n", "| 📊 销售部 | 2 人 |\n", "| 👥 人力资源 | 1 人 |\n", "\n", "**业务洞察:** 目前公司共有 5 名员工,技术部和销售部各占 2 人,人力资源部 1 人。技术部和销售部人员配置相对均衡。\n" ] } ], "source": [ "# 遍历所有消息,还原 Agent 的完整推理链\n", "response = sql_agent.invoke({\n", " \"messages\": [{\"role\": \"user\", \"content\": \"每个部门各有多少人?\"}]\n", "})\n", "\n", "for i, msg in enumerate(response[\"messages\"]):\n", " msg_type = msg.__class__.__name__\n", "\n", " # AIMessage 中如果有 tool_calls,说明 Agent 调用了工具\n", " if hasattr(msg, \"tool_calls\") and msg.tool_calls:\n", " print(f\"\\n步骤 {i+1} [{msg_type}]:\")\n", " for tc in msg.tool_calls:\n", " print(f\" 调用工具:{tc['name']}\")\n", " print(f\" 参数:{tc.get('args', {})}\")\n", "\n", " # ToolMessage 是工具返回的结果\n", " elif msg_type == \"ToolMessage\":\n", " print(f\" 工具返回:{msg.content[:200]}\")\n", "\n", " # AIMessage 的最终文本回答\n", " elif msg.content:\n", " print(f\"\\n最终回答:{msg.content}\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "019a2781", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ 数据加载完成\n", " 员工表:5 行 × 5 列\n", " 产品表:4 行 × 5 列\n", " 订单表:5 行 × 5 列\n", "\n", "📋 员工表示例:\n", " id name department salary hire_date\n", " 1 张三 技术部 20000.0 2023-01-15\n", " 2 李四 销售部 11000.0 2023-02-20\n", " 3 王五 技术部 16000.0 2022-11-10\n" ] } ], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import numpy as np\n", "\n", "# ============================================================\n", "# 配置 matplotlib 中文显示\n", "# ============================================================\n", "# Windows 用 SimHei(黑体),macOS 用 PingFang SC\n", "# 如果还是乱码,试试安装 fonts-noto-cjk 并清除缓存\n", "plt.rcParams[\"font.sans-serif\"] = [\"SimHei\", \"PingFang SC\", \"DejaVu Sans\"]\n", "plt.rcParams[\"axes.unicode_minus\"] = False # 解决负号显示为方块的问题\n", "\n", "# ============================================================\n", "# 从数据库加载数据到 Pandas DataFrame\n", "# ============================================================\n", "# 用 SQLAlchemy engine 复用连接,避免重复创建连接池\n", "from sqlalchemy import create_engine\n", "\n", "engine = create_engine(db_uri)\n", "\n", "employees_df = pd.read_sql(\"SELECT * FROM employees\", engine)\n", "products_df = pd.read_sql(\"SELECT * FROM products\", engine)\n", "orders_df = pd.read_sql(\"SELECT * FROM orders\", engine)\n", "\n", "print(\"✅ 数据加载完成\")\n", "print(f\" 员工表:{len(employees_df)} 行 × {len(employees_df.columns)} 列\")\n", "print(f\" 产品表:{len(products_df)} 行 × {len(products_df.columns)} 列\")\n", "print(f\" 订单表:{len(orders_df)} 行 × {len(orders_df.columns)} 列\")\n", "print(f\"\\n📋 员工表示例:\")\n", "print(employees_df.head(3).to_string(index=False))" ] }, { "cell_type": "code", "execution_count": 8, "id": "4c043990", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Python 代码执行沙箱创建成功\n" ] } ], "source": [ "import traceback\n", "from io import StringIO\n", "from contextlib import redirect_stdout\n", "from langchain.tools import tool\n", "\n", "# ============================================================\n", "# 定义沙箱的\"白名单\"——只有这些库和数据可以被代码访问\n", "# ============================================================\n", "# 这是一种简单的安全策略:不在白名单里的东西,代码碰不到\n", "SANDBOX_GLOBALS = {\n", " # 数据:Agent 可以分析这三张表\n", " \"employees_df\": employees_df,\n", " \"products_df\": products_df,\n", " \"orders_df\": orders_df,\n", " # 工具库:Agent 可以用这些库做分析和画图\n", " \"pd\": pd, # pandas — 数据处理\n", " \"plt\": plt, # matplotlib — 基础绑图\n", " \"sns\": sns, # seaborn — 统计可视化\n", " \"np\": np, # numpy — 数值计算\n", "}\n", "\n", "\n", "@tool\n", "def execute_python_code(code: str) -> str:\n", " \"\"\"\n", " 执行 Python 代码进行数据分析和可视化。\n", "\n", " 可用变量:\n", " - employees_df: 员工表 DataFrame(字段:id, name, department, salary, hire_date)\n", " - products_df: 产品表 DataFrame(字段:id, product_name, category, price, stock)\n", " - orders_df: 订单表 DataFrame(字段:id, employee_id, product_id, quantity, order_date)\n", " - pd: pandas 库\n", " - plt: matplotlib.pyplot\n", " - sns: seaborn\n", " - np: numpy\n", "\n", " 使用示例:\n", " # 统计各部门平均薪资\n", " result = employees_df.groupby('department')['salary'].mean()\n", " print(result)\n", "\n", " # 画柱状图\n", " plt.figure(figsize=(10, 6))\n", " employees_df.groupby('department')['salary'].mean().plot(kind='bar')\n", " plt.title('各部门平均薪资')\n", " plt.show()\n", " \"\"\"\n", " # 准备隔离的执行环境\n", " # globals_dict 提供白名单变量,locals_dict 收集执行过程中产生的新变量\n", " exec_globals = dict(SANDBOX_GLOBALS)\n", " exec_locals = {}\n", "\n", " # 用 StringIO 捕获 print() 的输出\n", " # 这样 Agent 生成的代码里所有的 print 语句都会被收集\n", " output_buffer = StringIO()\n", "\n", " try:\n", " # redirect_stdout 会把标准输出重定向到我们的 buffer\n", " with redirect_stdout(output_buffer):\n", " exec(code, exec_globals, exec_locals)\n", "\n", " result = output_buffer.getvalue()\n", "\n", " # 如果代码没有 print 任何东西,给个默认提示\n", " if not result.strip():\n", " result = \"✅ 代码执行成功(无文本输出,可能已生成图表)\"\n", "\n", " return f\"执行成功:\\n{result}\"\n", "\n", " except Exception as e:\n", " # 出错时返回完整的错误堆栈,方便 Agent 自我修正\n", " error_detail = traceback.format_exc()\n", " return f\"❌ 执行出错:{e}\\n\\n{error_detail}\"\n", "\n", "\n", "print(\"✅ Python 代码执行沙箱创建成功\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "dbeb1792", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "数据可视化 Agent 创建完成\n" ] } ], "source": [ "from langchain.agents import create_agent\n", "\n", "# ============================================================\n", "# 定义可视化 Agent 的 System Prompt\n", "# ============================================================\n", "VISUALIZATION_PROMPT = \"\"\"你是一名资深数据分析师,精通 Python、Pandas 和 Matplotlib 数据可视化。\n", "\n", "## 可用数据\n", "1. employees_df — 员工表(字段:id, name, department, salary, hire_date)\n", "2. products_df — 产品表(字段:id, product_name, category, price, stock)\n", "3. orders_df — 订单表(字段:id, employee_id, product_id, quantity, order_date)\n", "\n", "## 工作流程\n", "1. 理解用户的分析需求\n", "2. 用 execute_python_code 工具编写并执行 Python 代码\n", "3. 先做数据探索(head、describe、info),再做深入分析\n", "4. 用中文解释分析结果,给出业务洞察\n", "\n", "## 代码规范\n", "- 绑图前设置中文字体:plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']\n", "- 设置 plt.rcParams['axes.unicode_minus'] = False\n", "- 图表尺寸统一用 plt.figure(figsize=(10, 6))\n", "- 必须添加标题、坐标轴标签,让图表自解释\n", "- 用 print() 输出关键统计量,不要只画图不说话\n", "- 图表标题用英文(避免渲染问题),但用中文向用户解释结果\n", "\n", "## 注意事项\n", "- 每次只执行一段完整的代码,不要拆成多段\n", "- 先探索数据结构,再做分析——不要上来就画图\n", "- 结果要有业务洞察,不只是\"最大值是 XXX\"\n", "\"\"\"\n", "\n", "# ============================================================\n", "# 组装可视化 Agent(langchain 1.3.1 写法)\n", "# ============================================================\n", "viz_tools = [execute_python_code]\n", "\n", "# 同样用 create_agent 一行搞定,和 SQL Agent 的创建方式完全一致\n", "visualization_agent = create_agent(\n", " model=llm,\n", " tools=viz_tools,\n", " system_prompt=VISUALIZATION_PROMPT,\n", ")\n", "\n", "print(\"数据可视化 Agent 创建完成\")" ] } ], "metadata": { "kernelspec": { "display_name": "analytics_demo", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 5 }