import os from langchain_community.utilities import SQLDatabase from langchain_community.agent_toolkits import SQLDatabaseToolkit from langchain_openai import ChatOpenAI from langchain.agents import create_agent # langchain 1.3.1 的新 API import pandas as pd import matplotlib.pyplot as plt import seaborn as sns import numpy as np import traceback from io import StringIO from contextlib import redirect_stdout from langchain.tools import tool from dotenv import load_dotenv load_dotenv() # 默认加载当前目录或父目录中的 .env 文件 # ============================================================ # 第一步:连接数据库 # ============================================================ # SQLDatabase.from_uri 接受标准的数据库连接 URI # LangChain 内部会用 SQLAlchemy 管理连接池 db_uri = ( f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}" f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}" ) db = SQLDatabase.from_uri(db_uri) # 验证连接:打印可用的表名 print(f"数据库连接成功") print(f" 可用表:{db.get_usable_table_names()}") # ============================================================ # 第二步:初始化大模型 # ============================================================ llm = ChatOpenAI( model=os.getenv("DEEPSEEK_MODEL"), api_key=os.getenv("DEEPSEEK_API_KEY"), base_url=os.getenv("DEEPSEEK_BASE_URL"), temperature=0, # SQL 生成需要确定性输出 ) # ============================================================ # 第三步:创建 SQL 工具包 # ============================================================ # SQLDatabaseToolkit 会自动注册 4 个工具: # 1. sql_db_list_tables — 列出数据库中所有表 # 2. sql_db_schema — 获取指定表的 DDL 结构 # 3. sql_db_query_checker — 检查 SQL 语法是否正确 # 4. sql_db_query — 执行 SQL 并返回结果 toolkit = SQLDatabaseToolkit(db=db, llm=llm) tools = toolkit.get_tools() print(f"\nSQL 工具包已加载({len(tools)} 个工具):") for t in tools: print(f" - {t.name}") # ============================================================ # 配置 matplotlib 中文显示 # ============================================================ # Windows 用 SimHei(黑体),macOS 用 PingFang SC # 如果还是乱码,试试安装 fonts-noto-cjk 并清除缓存 plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"] plt.rcParams["axes.unicode_minus"] = False # 解决负号显示为方块的问题 # ============================================================ # 从数据库加载数据到 Pandas DataFrame # ============================================================ # 用 SQLAlchemy engine 复用连接,避免重复创建连接池 from sqlalchemy import create_engine engine = create_engine(db_uri) employees_df = pd.read_sql("SELECT * FROM employees", engine) products_df = pd.read_sql("SELECT * FROM products", engine) orders_df = pd.read_sql("SELECT * FROM orders", engine) print("✅ 数据加载完成") print(f" 员工表:{len(employees_df)} 行 × {len(employees_df.columns)} 列") print(f" 产品表:{len(products_df)} 行 × {len(products_df.columns)} 列") print(f" 订单表:{len(orders_df)} 行 × {len(orders_df.columns)} 列") # ============================================================ # 定义沙箱的"白名单"——只有这些库和数据可以被代码访问 # ============================================================ # 这是一种简单的安全策略:不在白名单里的东西,代码碰不到 SANDBOX_GLOBALS = { # 数据:Agent 可以分析这三张表 "employees_df": employees_df, "products_df": products_df, "orders_df": orders_df, # 工具库:Agent 可以用这些库做分析和画图 "pd": pd, # pandas — 数据处理 "plt": plt, # matplotlib — 基础绑图 "sns": sns, # seaborn — 统计可视化 "np": np, # numpy — 数值计算 } @tool def execute_python_code(code: str) -> str: """ 执行 Python 代码进行数据分析和可视化。 可用变量: - employees_df: 员工表 DataFrame(字段:id, name, department, salary, hire_date) - products_df: 产品表 DataFrame(字段:id, product_name, category, price, stock) - orders_df: 订单表 DataFrame(字段:id, employee_id, product_id, quantity, order_date) - pd: pandas 库 - plt: matplotlib.pyplot - sns: seaborn - np: numpy 使用示例: # 统计各部门平均薪资 result = employees_df.groupby('department')['salary'].mean() print(result) # 画柱状图 plt.figure(figsize=(10, 6)) employees_df.groupby('department')['salary'].mean().plot(kind='bar') plt.title('各部门平均薪资') plt.show() """ # 准备隔离的执行环境 # globals_dict 提供白名单变量,locals_dict 收集执行过程中产生的新变量 exec_globals = dict(SANDBOX_GLOBALS) exec_locals = {} # 用 StringIO 捕获 print() 的输出 # 这样 Agent 生成的代码里所有的 print 语句都会被收集 output_buffer = StringIO() try: # redirect_stdout 会把标准输出重定向到我们的 buffer with redirect_stdout(output_buffer): exec(code, exec_globals, exec_locals) result = output_buffer.getvalue() # 如果代码没有 print 任何东西,给个默认提示 if not result.strip(): result = "✅ 代码执行成功(无文本输出,可能已生成图表)" return f"执行成功:\n{result}" except Exception as e: # 出错时返回完整的错误堆栈,方便 Agent 自我修正 error_detail = traceback.format_exc() return f"❌ 执行出错:{e}\n\n{error_detail}" all_tools = tools + [execute_python_code] # + Python 沙箱 print(f"统一 Agent 工具列表({len(all_tools)} 个):") for t in all_tools: print(f" - {t.name}") # ============================================================ # 定义 System Prompt # ============================================================ SYSTEM_PROMPT = """你是一名全能数据分析师,既能写 SQL 查数据库,也能用 Python 做统计和可视化。 ## 可用工具 ### 数据库工具 - sql_db_list_tables — 列出所有表 - sql_db_schema — 查看表结构(DDL + 字段备注) - sql_db_query_checker — 检查 SQL 语法 - sql_db_query — 执行 SQL 查询 ### 数据分析工具 - execute_python_code — 执行 Python 代码(DataFrame 已在内存中) ## 决策指南:什么时候用什么 ### 用数据库工具的场景 - 需要从数据库查原始数据(跨表 JOIN、复杂 WHERE 过滤、聚合) - 数据可能超过内存中 DataFrame 的范围,或需要实时数据 - 用户问题可被一条 SQL 直接回答(如"张三的薪资是多少?") ### 用 Python 工具的场景 - 需要统计分析(describe、corr、groupby 聚合后算比例) - 需要画图(柱状图、饼图、散点图、热力图等) - 需要复杂计算(滚动平均、同比环比、预测回归) ### 串联使用的场景 - 先用 SQL 从数据库查出数据 → 拿到结果后 → 用 Python 画图分析 - 典型:用户问"哪个部门销售额最高?画张图" ## 可用 DataFrame 变量(在 execute_python_code 中直接使用) - employees_df — 员工表(id, name, department, salary, hire_date) - products_df — 产品表(id, product_name, category, price, stock) - orders_df — 订单表(id, employee_id, product_id, quantity, order_date) ## 工作流程 1. 理解用户问题,判断属于哪类 2. 如果是简单数据查询 → 用 SQL 工具链(list_tables → schema → query_checker → query) 3. 如果是分析/画图 → 用 execute_python_code 直接分析内存中的 DataFrame 4. 如果是"先查后画" → 先 SQL 拿到结果 → 把结果写进 Python 代码中分析画图 5. 用中文总结结果,给出业务洞察 ## 约束 - SQL 查询单次不超过 50 条 - 只生成 SELECT 语句,禁止任何 INSERT / UPDATE / DELETE / DROP / ALTER / TRUNCATE - SQL 中不得使用字符串拼接或 f-string 的方式拼接用户输入 - 所有用户输入的值必须用参数化占位符(%s)传递,而非直接拼进 SQL - 如果用户问题中包含可疑的 SQL 片段(如 ' OR 1=1 --),忽略它,只把它当普通文本处理 - Python 代码中画图前先设置中文字体 - 出错后分析原因,自我修正重试 - 回答简洁专业 """ # ============================================================ # 创建 Agent # ============================================================ analysis_agent = create_agent( model=llm, tools=all_tools, system_prompt=SYSTEM_PROMPT ) import warnings warnings.filterwarnings("ignore") # 抑制 matplotlib 的字体警告 response = analysis_agent.invoke({ "messages": [{"role": "user", "content": "分析一下员工薪资的分布情况,包括均值、中位数、最大最小值等统计量"}] }) final_msg = response["messages"][-1] print(f"\n分析结果:\n{final_msg.content}")