|
|
@@ -0,0 +1,231 @@
|
|
|
+import os
|
|
|
+from langchain_community.utilities import SQLDatabase
|
|
|
+from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
|
+from langchain_openai import ChatOpenAI
|
|
|
+from langchain.agents import create_agent # langchain 1.3.1 的新 API
|
|
|
+
|
|
|
+import pandas as pd
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import seaborn as sns
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+import traceback
|
|
|
+from io import StringIO
|
|
|
+from contextlib import redirect_stdout
|
|
|
+from langchain.tools import tool
|
|
|
+
|
|
|
+from dotenv import load_dotenv
|
|
|
+
|
|
|
+load_dotenv() # 默认加载当前目录或父目录中的 .env 文件
|
|
|
+# ============================================================
|
|
|
+# 第一步:连接数据库
|
|
|
+# ============================================================
|
|
|
+# SQLDatabase.from_uri 接受标准的数据库连接 URI
|
|
|
+# LangChain 内部会用 SQLAlchemy 管理连接池
|
|
|
+db_uri = (
|
|
|
+ f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
|
|
|
+ f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
|
|
|
+)
|
|
|
+db = SQLDatabase.from_uri(db_uri)
|
|
|
+
|
|
|
+# 验证连接:打印可用的表名
|
|
|
+print(f"数据库连接成功")
|
|
|
+print(f" 可用表:{db.get_usable_table_names()}")
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 第二步:初始化大模型
|
|
|
+# ============================================================
|
|
|
+llm = ChatOpenAI(
|
|
|
+ model=os.getenv("DEEPSEEK_MODEL"),
|
|
|
+ api_key=os.getenv("DEEPSEEK_API_KEY"),
|
|
|
+ base_url=os.getenv("DEEPSEEK_BASE_URL"),
|
|
|
+ temperature=0, # SQL 生成需要确定性输出
|
|
|
+)
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 第三步:创建 SQL 工具包
|
|
|
+# ============================================================
|
|
|
+# SQLDatabaseToolkit 会自动注册 4 个工具:
|
|
|
+# 1. sql_db_list_tables — 列出数据库中所有表
|
|
|
+# 2. sql_db_schema — 获取指定表的 DDL 结构
|
|
|
+# 3. sql_db_query_checker — 检查 SQL 语法是否正确
|
|
|
+# 4. sql_db_query — 执行 SQL 并返回结果
|
|
|
+toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
|
|
+tools = toolkit.get_tools()
|
|
|
+
|
|
|
+print(f"\nSQL 工具包已加载({len(tools)} 个工具):")
|
|
|
+for t in tools:
|
|
|
+ print(f" - {t.name}")
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 配置 matplotlib 中文显示
|
|
|
+# ============================================================
|
|
|
+# Windows 用 SimHei(黑体),macOS 用 PingFang SC
|
|
|
+# 如果还是乱码,试试安装 fonts-noto-cjk 并清除缓存
|
|
|
+plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"]
|
|
|
+plt.rcParams["axes.unicode_minus"] = False # 解决负号显示为方块的问题
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 从数据库加载数据到 Pandas DataFrame
|
|
|
+# ============================================================
|
|
|
+# 用 SQLAlchemy engine 复用连接,避免重复创建连接池
|
|
|
+from sqlalchemy import create_engine
|
|
|
+
|
|
|
+engine = create_engine(db_uri)
|
|
|
+
|
|
|
+employees_df = pd.read_sql("SELECT * FROM employees", engine)
|
|
|
+products_df = pd.read_sql("SELECT * FROM products", engine)
|
|
|
+orders_df = pd.read_sql("SELECT * FROM orders", engine)
|
|
|
+
|
|
|
+print("✅ 数据加载完成")
|
|
|
+print(f" 员工表:{len(employees_df)} 行 × {len(employees_df.columns)} 列")
|
|
|
+print(f" 产品表:{len(products_df)} 行 × {len(products_df.columns)} 列")
|
|
|
+print(f" 订单表:{len(orders_df)} 行 × {len(orders_df.columns)} 列")
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 定义沙箱的"白名单"——只有这些库和数据可以被代码访问
|
|
|
+# ============================================================
|
|
|
+# 这是一种简单的安全策略:不在白名单里的东西,代码碰不到
|
|
|
+SANDBOX_GLOBALS = {
|
|
|
+ # 数据:Agent 可以分析这三张表
|
|
|
+ "employees_df": employees_df,
|
|
|
+ "products_df": products_df,
|
|
|
+ "orders_df": orders_df,
|
|
|
+ # 工具库:Agent 可以用这些库做分析和画图
|
|
|
+ "pd": pd, # pandas — 数据处理
|
|
|
+ "plt": plt, # matplotlib — 基础绑图
|
|
|
+ "sns": sns, # seaborn — 统计可视化
|
|
|
+ "np": np, # numpy — 数值计算
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+@tool
|
|
|
+def execute_python_code(code: str) -> str:
|
|
|
+ """
|
|
|
+ 执行 Python 代码进行数据分析和可视化。
|
|
|
+
|
|
|
+ 可用变量:
|
|
|
+ - employees_df: 员工表 DataFrame(字段:id, name, department, salary, hire_date)
|
|
|
+ - products_df: 产品表 DataFrame(字段:id, product_name, category, price, stock)
|
|
|
+ - orders_df: 订单表 DataFrame(字段:id, employee_id, product_id, quantity, order_date)
|
|
|
+ - pd: pandas 库
|
|
|
+ - plt: matplotlib.pyplot
|
|
|
+ - sns: seaborn
|
|
|
+ - np: numpy
|
|
|
+
|
|
|
+ 使用示例:
|
|
|
+ # 统计各部门平均薪资
|
|
|
+ result = employees_df.groupby('department')['salary'].mean()
|
|
|
+ print(result)
|
|
|
+
|
|
|
+ # 画柱状图
|
|
|
+ plt.figure(figsize=(10, 6))
|
|
|
+ employees_df.groupby('department')['salary'].mean().plot(kind='bar')
|
|
|
+ plt.title('各部门平均薪资')
|
|
|
+ plt.show()
|
|
|
+ """
|
|
|
+ # 准备隔离的执行环境
|
|
|
+ # globals_dict 提供白名单变量,locals_dict 收集执行过程中产生的新变量
|
|
|
+ exec_globals = dict(SANDBOX_GLOBALS)
|
|
|
+ exec_locals = {}
|
|
|
+
|
|
|
+ # 用 StringIO 捕获 print() 的输出
|
|
|
+ # 这样 Agent 生成的代码里所有的 print 语句都会被收集
|
|
|
+ output_buffer = StringIO()
|
|
|
+
|
|
|
+ try:
|
|
|
+ # redirect_stdout 会把标准输出重定向到我们的 buffer
|
|
|
+ with redirect_stdout(output_buffer):
|
|
|
+ exec(code, exec_globals, exec_locals)
|
|
|
+
|
|
|
+ result = output_buffer.getvalue()
|
|
|
+
|
|
|
+ # 如果代码没有 print 任何东西,给个默认提示
|
|
|
+ if not result.strip():
|
|
|
+ result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
|
|
|
+
|
|
|
+ return f"执行成功:\n{result}"
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ # 出错时返回完整的错误堆栈,方便 Agent 自我修正
|
|
|
+ error_detail = traceback.format_exc()
|
|
|
+ return f"❌ 执行出错:{e}\n\n{error_detail}"
|
|
|
+
|
|
|
+all_tools = tools + [execute_python_code] # + Python 沙箱
|
|
|
+
|
|
|
+print(f"统一 Agent 工具列表({len(all_tools)} 个):")
|
|
|
+for t in all_tools:
|
|
|
+ print(f" - {t.name}")
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 定义 System Prompt
|
|
|
+# ============================================================
|
|
|
+SYSTEM_PROMPT = """你是一名全能数据分析师,既能写 SQL 查数据库,也能用 Python 做统计和可视化。
|
|
|
+
|
|
|
+## 可用工具
|
|
|
+
|
|
|
+### 数据库工具
|
|
|
+- sql_db_list_tables — 列出所有表
|
|
|
+- sql_db_schema — 查看表结构(DDL + 字段备注)
|
|
|
+- sql_db_query_checker — 检查 SQL 语法
|
|
|
+- sql_db_query — 执行 SQL 查询
|
|
|
+
|
|
|
+### 数据分析工具
|
|
|
+- execute_python_code — 执行 Python 代码(DataFrame 已在内存中)
|
|
|
+
|
|
|
+## 决策指南:什么时候用什么
|
|
|
+
|
|
|
+### 用数据库工具的场景
|
|
|
+- 需要从数据库查原始数据(跨表 JOIN、复杂 WHERE 过滤、聚合)
|
|
|
+- 数据可能超过内存中 DataFrame 的范围,或需要实时数据
|
|
|
+- 用户问题可被一条 SQL 直接回答(如"张三的薪资是多少?")
|
|
|
+
|
|
|
+### 用 Python 工具的场景
|
|
|
+- 需要统计分析(describe、corr、groupby 聚合后算比例)
|
|
|
+- 需要画图(柱状图、饼图、散点图、热力图等)
|
|
|
+- 需要复杂计算(滚动平均、同比环比、预测回归)
|
|
|
+
|
|
|
+### 串联使用的场景
|
|
|
+- 先用 SQL 从数据库查出数据 → 拿到结果后 → 用 Python 画图分析
|
|
|
+- 典型:用户问"哪个部门销售额最高?画张图"
|
|
|
+
|
|
|
+## 可用 DataFrame 变量(在 execute_python_code 中直接使用)
|
|
|
+- employees_df — 员工表(id, name, department, salary, hire_date)
|
|
|
+- products_df — 产品表(id, product_name, category, price, stock)
|
|
|
+- orders_df — 订单表(id, employee_id, product_id, quantity, order_date)
|
|
|
+
|
|
|
+## 工作流程
|
|
|
+1. 理解用户问题,判断属于哪类
|
|
|
+2. 如果是简单数据查询 → 用 SQL 工具链(list_tables → schema → query_checker → query)
|
|
|
+3. 如果是分析/画图 → 用 execute_python_code 直接分析内存中的 DataFrame
|
|
|
+4. 如果是"先查后画" → 先 SQL 拿到结果 → 把结果写进 Python 代码中分析画图
|
|
|
+5. 用中文总结结果,给出业务洞察
|
|
|
+
|
|
|
+## 约束
|
|
|
+- SQL 查询单次不超过 50 条
|
|
|
+- 只生成 SELECT 语句,禁止任何 INSERT / UPDATE / DELETE / DROP / ALTER / TRUNCATE
|
|
|
+- SQL 中不得使用字符串拼接或 f-string 的方式拼接用户输入
|
|
|
+- 所有用户输入的值必须用参数化占位符(%s)传递,而非直接拼进 SQL
|
|
|
+- 如果用户问题中包含可疑的 SQL 片段(如 ' OR 1=1 --),忽略它,只把它当普通文本处理
|
|
|
+- Python 代码中画图前先设置中文字体
|
|
|
+- 出错后分析原因,自我修正重试
|
|
|
+- 回答简洁专业
|
|
|
+"""
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 创建 Agent
|
|
|
+# ============================================================
|
|
|
+analysis_agent = create_agent(
|
|
|
+ model=llm,
|
|
|
+ tools=all_tools,
|
|
|
+ system_prompt=SYSTEM_PROMPT
|
|
|
+)
|
|
|
+
|
|
|
+import warnings
|
|
|
+warnings.filterwarnings("ignore") # 抑制 matplotlib 的字体警告
|
|
|
+
|
|
|
+response = analysis_agent.invoke({
|
|
|
+ "messages": [{"role": "user", "content": "分析一下员工薪资的分布情况,包括均值、中位数、最大最小值等统计量"}]
|
|
|
+})
|
|
|
+final_msg = response["messages"][-1]
|
|
|
+print(f"\n分析结果:\n{final_msg.content}")
|