| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231 |
- import os
- from langchain_community.utilities import SQLDatabase
- from langchain_community.agent_toolkits import SQLDatabaseToolkit
- from langchain_openai import ChatOpenAI
- from langchain.agents import create_agent # langchain 1.3.1 的新 API
- import pandas as pd
- import matplotlib.pyplot as plt
- import seaborn as sns
- import numpy as np
- import traceback
- from io import StringIO
- from contextlib import redirect_stdout
- from langchain.tools import tool
- from dotenv import load_dotenv
- load_dotenv() # 默认加载当前目录或父目录中的 .env 文件
- # ============================================================
- # 第一步:连接数据库
- # ============================================================
- # SQLDatabase.from_uri 接受标准的数据库连接 URI
- # LangChain 内部会用 SQLAlchemy 管理连接池
- db_uri = (
- f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
- f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
- )
- db = SQLDatabase.from_uri(db_uri)
- # 验证连接:打印可用的表名
- print(f"数据库连接成功")
- print(f" 可用表:{db.get_usable_table_names()}")
- # ============================================================
- # 第二步:初始化大模型
- # ============================================================
- llm = ChatOpenAI(
- model=os.getenv("DEEPSEEK_MODEL"),
- api_key=os.getenv("DEEPSEEK_API_KEY"),
- base_url=os.getenv("DEEPSEEK_BASE_URL"),
- temperature=0, # SQL 生成需要确定性输出
- )
- # ============================================================
- # 第三步:创建 SQL 工具包
- # ============================================================
- # SQLDatabaseToolkit 会自动注册 4 个工具:
- # 1. sql_db_list_tables — 列出数据库中所有表
- # 2. sql_db_schema — 获取指定表的 DDL 结构
- # 3. sql_db_query_checker — 检查 SQL 语法是否正确
- # 4. sql_db_query — 执行 SQL 并返回结果
- toolkit = SQLDatabaseToolkit(db=db, llm=llm)
- tools = toolkit.get_tools()
- print(f"\nSQL 工具包已加载({len(tools)} 个工具):")
- for t in tools:
- print(f" - {t.name}")
- # ============================================================
- # 配置 matplotlib 中文显示
- # ============================================================
- # Windows 用 SimHei(黑体),macOS 用 PingFang SC
- # 如果还是乱码,试试安装 fonts-noto-cjk 并清除缓存
- plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"]
- plt.rcParams["axes.unicode_minus"] = False # 解决负号显示为方块的问题
- # ============================================================
- # 从数据库加载数据到 Pandas DataFrame
- # ============================================================
- # 用 SQLAlchemy engine 复用连接,避免重复创建连接池
- from sqlalchemy import create_engine
- engine = create_engine(db_uri)
- employees_df = pd.read_sql("SELECT * FROM employees", engine)
- products_df = pd.read_sql("SELECT * FROM products", engine)
- orders_df = pd.read_sql("SELECT * FROM orders", engine)
- print("✅ 数据加载完成")
- print(f" 员工表:{len(employees_df)} 行 × {len(employees_df.columns)} 列")
- print(f" 产品表:{len(products_df)} 行 × {len(products_df.columns)} 列")
- print(f" 订单表:{len(orders_df)} 行 × {len(orders_df.columns)} 列")
- # ============================================================
- # 定义沙箱的"白名单"——只有这些库和数据可以被代码访问
- # ============================================================
- # 这是一种简单的安全策略:不在白名单里的东西,代码碰不到
- SANDBOX_GLOBALS = {
- # 数据:Agent 可以分析这三张表
- "employees_df": employees_df,
- "products_df": products_df,
- "orders_df": orders_df,
- # 工具库:Agent 可以用这些库做分析和画图
- "pd": pd, # pandas — 数据处理
- "plt": plt, # matplotlib — 基础绑图
- "sns": sns, # seaborn — 统计可视化
- "np": np, # numpy — 数值计算
- }
- @tool
- def execute_python_code(code: str) -> str:
- """
- 执行 Python 代码进行数据分析和可视化。
- 可用变量:
- - employees_df: 员工表 DataFrame(字段:id, name, department, salary, hire_date)
- - products_df: 产品表 DataFrame(字段:id, product_name, category, price, stock)
- - orders_df: 订单表 DataFrame(字段:id, employee_id, product_id, quantity, order_date)
- - pd: pandas 库
- - plt: matplotlib.pyplot
- - sns: seaborn
- - np: numpy
- 使用示例:
- # 统计各部门平均薪资
- result = employees_df.groupby('department')['salary'].mean()
- print(result)
- # 画柱状图
- plt.figure(figsize=(10, 6))
- employees_df.groupby('department')['salary'].mean().plot(kind='bar')
- plt.title('各部门平均薪资')
- plt.show()
- """
- # 准备隔离的执行环境
- # globals_dict 提供白名单变量,locals_dict 收集执行过程中产生的新变量
- exec_globals = dict(SANDBOX_GLOBALS)
- exec_locals = {}
- # 用 StringIO 捕获 print() 的输出
- # 这样 Agent 生成的代码里所有的 print 语句都会被收集
- output_buffer = StringIO()
- try:
- # redirect_stdout 会把标准输出重定向到我们的 buffer
- with redirect_stdout(output_buffer):
- exec(code, exec_globals, exec_locals)
- result = output_buffer.getvalue()
- # 如果代码没有 print 任何东西,给个默认提示
- if not result.strip():
- result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
- return f"执行成功:\n{result}"
- except Exception as e:
- # 出错时返回完整的错误堆栈,方便 Agent 自我修正
- error_detail = traceback.format_exc()
- return f"❌ 执行出错:{e}\n\n{error_detail}"
- all_tools = tools + [execute_python_code] # + Python 沙箱
- print(f"统一 Agent 工具列表({len(all_tools)} 个):")
- for t in all_tools:
- print(f" - {t.name}")
- # ============================================================
- # 定义 System Prompt
- # ============================================================
- SYSTEM_PROMPT = """你是一名全能数据分析师,既能写 SQL 查数据库,也能用 Python 做统计和可视化。
- ## 可用工具
- ### 数据库工具
- - sql_db_list_tables — 列出所有表
- - sql_db_schema — 查看表结构(DDL + 字段备注)
- - sql_db_query_checker — 检查 SQL 语法
- - sql_db_query — 执行 SQL 查询
- ### 数据分析工具
- - execute_python_code — 执行 Python 代码(DataFrame 已在内存中)
- ## 决策指南:什么时候用什么
- ### 用数据库工具的场景
- - 需要从数据库查原始数据(跨表 JOIN、复杂 WHERE 过滤、聚合)
- - 数据可能超过内存中 DataFrame 的范围,或需要实时数据
- - 用户问题可被一条 SQL 直接回答(如"张三的薪资是多少?")
- ### 用 Python 工具的场景
- - 需要统计分析(describe、corr、groupby 聚合后算比例)
- - 需要画图(柱状图、饼图、散点图、热力图等)
- - 需要复杂计算(滚动平均、同比环比、预测回归)
- ### 串联使用的场景
- - 先用 SQL 从数据库查出数据 → 拿到结果后 → 用 Python 画图分析
- - 典型:用户问"哪个部门销售额最高?画张图"
- ## 可用 DataFrame 变量(在 execute_python_code 中直接使用)
- - employees_df — 员工表(id, name, department, salary, hire_date)
- - products_df — 产品表(id, product_name, category, price, stock)
- - orders_df — 订单表(id, employee_id, product_id, quantity, order_date)
- ## 工作流程
- 1. 理解用户问题,判断属于哪类
- 2. 如果是简单数据查询 → 用 SQL 工具链(list_tables → schema → query_checker → query)
- 3. 如果是分析/画图 → 用 execute_python_code 直接分析内存中的 DataFrame
- 4. 如果是"先查后画" → 先 SQL 拿到结果 → 把结果写进 Python 代码中分析画图
- 5. 用中文总结结果,给出业务洞察
- ## 约束
- - SQL 查询单次不超过 50 条
- - 只生成 SELECT 语句,禁止任何 INSERT / UPDATE / DELETE / DROP / ALTER / TRUNCATE
- - SQL 中不得使用字符串拼接或 f-string 的方式拼接用户输入
- - 所有用户输入的值必须用参数化占位符(%s)传递,而非直接拼进 SQL
- - 如果用户问题中包含可疑的 SQL 片段(如 ' OR 1=1 --),忽略它,只把它当普通文本处理
- - Python 代码中画图前先设置中文字体
- - 出错后分析原因,自我修正重试
- - 回答简洁专业
- """
- # ============================================================
- # 创建 Agent
- # ============================================================
- analysis_agent = create_agent(
- model=llm,
- tools=all_tools,
- system_prompt=SYSTEM_PROMPT
- )
- import warnings
- warnings.filterwarnings("ignore") # 抑制 matplotlib 的字体警告
- response = analysis_agent.invoke({
- "messages": [{"role": "user", "content": "分析一下员工薪资的分布情况,包括均值、中位数、最大最小值等统计量"}]
- })
- final_msg = response["messages"][-1]
- print(f"\n分析结果:\n{final_msg.content}")
|