01_data_analysis.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import os
  2. from langchain_community.utilities import SQLDatabase
  3. from langchain_community.agent_toolkits import SQLDatabaseToolkit
  4. from langchain_openai import ChatOpenAI
  5. from langchain.agents import create_agent # langchain 1.3.1 的新 API
  6. import pandas as pd
  7. import matplotlib.pyplot as plt
  8. import seaborn as sns
  9. import numpy as np
  10. import traceback
  11. from io import StringIO
  12. from contextlib import redirect_stdout
  13. from langchain.tools import tool
  14. from dotenv import load_dotenv
  15. load_dotenv() # 默认加载当前目录或父目录中的 .env 文件
  16. # ============================================================
  17. # 第一步:连接数据库
  18. # ============================================================
  19. # SQLDatabase.from_uri 接受标准的数据库连接 URI
  20. # LangChain 内部会用 SQLAlchemy 管理连接池
  21. db_uri = (
  22. f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
  23. f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
  24. )
  25. db = SQLDatabase.from_uri(db_uri)
  26. # 验证连接:打印可用的表名
  27. print(f"数据库连接成功")
  28. print(f" 可用表:{db.get_usable_table_names()}")
  29. # ============================================================
  30. # 第二步:初始化大模型
  31. # ============================================================
  32. llm = ChatOpenAI(
  33. model=os.getenv("DEEPSEEK_MODEL"),
  34. api_key=os.getenv("DEEPSEEK_API_KEY"),
  35. base_url=os.getenv("DEEPSEEK_BASE_URL"),
  36. temperature=0, # SQL 生成需要确定性输出
  37. )
  38. # ============================================================
  39. # 第三步:创建 SQL 工具包
  40. # ============================================================
  41. # SQLDatabaseToolkit 会自动注册 4 个工具:
  42. # 1. sql_db_list_tables — 列出数据库中所有表
  43. # 2. sql_db_schema — 获取指定表的 DDL 结构
  44. # 3. sql_db_query_checker — 检查 SQL 语法是否正确
  45. # 4. sql_db_query — 执行 SQL 并返回结果
  46. toolkit = SQLDatabaseToolkit(db=db, llm=llm)
  47. tools = toolkit.get_tools()
  48. print(f"\nSQL 工具包已加载({len(tools)} 个工具):")
  49. for t in tools:
  50. print(f" - {t.name}")
  51. # ============================================================
  52. # 配置 matplotlib 中文显示
  53. # ============================================================
  54. # Windows 用 SimHei(黑体),macOS 用 PingFang SC
  55. # 如果还是乱码,试试安装 fonts-noto-cjk 并清除缓存
  56. plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"]
  57. plt.rcParams["axes.unicode_minus"] = False # 解决负号显示为方块的问题
  58. # ============================================================
  59. # 从数据库加载数据到 Pandas DataFrame
  60. # ============================================================
  61. # 用 SQLAlchemy engine 复用连接,避免重复创建连接池
  62. from sqlalchemy import create_engine
  63. engine = create_engine(db_uri)
  64. employees_df = pd.read_sql("SELECT * FROM employees", engine)
  65. products_df = pd.read_sql("SELECT * FROM products", engine)
  66. orders_df = pd.read_sql("SELECT * FROM orders", engine)
  67. print("✅ 数据加载完成")
  68. print(f" 员工表:{len(employees_df)} 行 × {len(employees_df.columns)} 列")
  69. print(f" 产品表:{len(products_df)} 行 × {len(products_df.columns)} 列")
  70. print(f" 订单表:{len(orders_df)} 行 × {len(orders_df.columns)} 列")
  71. # ============================================================
  72. # 定义沙箱的"白名单"——只有这些库和数据可以被代码访问
  73. # ============================================================
  74. # 这是一种简单的安全策略:不在白名单里的东西,代码碰不到
  75. SANDBOX_GLOBALS = {
  76. # 数据:Agent 可以分析这三张表
  77. "employees_df": employees_df,
  78. "products_df": products_df,
  79. "orders_df": orders_df,
  80. # 工具库:Agent 可以用这些库做分析和画图
  81. "pd": pd, # pandas — 数据处理
  82. "plt": plt, # matplotlib — 基础绑图
  83. "sns": sns, # seaborn — 统计可视化
  84. "np": np, # numpy — 数值计算
  85. }
  86. @tool
  87. def execute_python_code(code: str) -> str:
  88. """
  89. 执行 Python 代码进行数据分析和可视化。
  90. 可用变量:
  91. - employees_df: 员工表 DataFrame(字段:id, name, department, salary, hire_date)
  92. - products_df: 产品表 DataFrame(字段:id, product_name, category, price, stock)
  93. - orders_df: 订单表 DataFrame(字段:id, employee_id, product_id, quantity, order_date)
  94. - pd: pandas 库
  95. - plt: matplotlib.pyplot
  96. - sns: seaborn
  97. - np: numpy
  98. 使用示例:
  99. # 统计各部门平均薪资
  100. result = employees_df.groupby('department')['salary'].mean()
  101. print(result)
  102. # 画柱状图
  103. plt.figure(figsize=(10, 6))
  104. employees_df.groupby('department')['salary'].mean().plot(kind='bar')
  105. plt.title('各部门平均薪资')
  106. plt.show()
  107. """
  108. # 准备隔离的执行环境
  109. # globals_dict 提供白名单变量,locals_dict 收集执行过程中产生的新变量
  110. exec_globals = dict(SANDBOX_GLOBALS)
  111. exec_locals = {}
  112. # 用 StringIO 捕获 print() 的输出
  113. # 这样 Agent 生成的代码里所有的 print 语句都会被收集
  114. output_buffer = StringIO()
  115. try:
  116. # redirect_stdout 会把标准输出重定向到我们的 buffer
  117. with redirect_stdout(output_buffer):
  118. exec(code, exec_globals, exec_locals)
  119. result = output_buffer.getvalue()
  120. # 如果代码没有 print 任何东西,给个默认提示
  121. if not result.strip():
  122. result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
  123. return f"执行成功:\n{result}"
  124. except Exception as e:
  125. # 出错时返回完整的错误堆栈,方便 Agent 自我修正
  126. error_detail = traceback.format_exc()
  127. return f"❌ 执行出错:{e}\n\n{error_detail}"
  128. all_tools = tools + [execute_python_code] # + Python 沙箱
  129. print(f"统一 Agent 工具列表({len(all_tools)} 个):")
  130. for t in all_tools:
  131. print(f" - {t.name}")
  132. # ============================================================
  133. # 定义 System Prompt
  134. # ============================================================
  135. SYSTEM_PROMPT = """你是一名全能数据分析师,既能写 SQL 查数据库,也能用 Python 做统计和可视化。
  136. ## 可用工具
  137. ### 数据库工具
  138. - sql_db_list_tables — 列出所有表
  139. - sql_db_schema — 查看表结构(DDL + 字段备注)
  140. - sql_db_query_checker — 检查 SQL 语法
  141. - sql_db_query — 执行 SQL 查询
  142. ### 数据分析工具
  143. - execute_python_code — 执行 Python 代码(DataFrame 已在内存中)
  144. ## 决策指南:什么时候用什么
  145. ### 用数据库工具的场景
  146. - 需要从数据库查原始数据(跨表 JOIN、复杂 WHERE 过滤、聚合)
  147. - 数据可能超过内存中 DataFrame 的范围,或需要实时数据
  148. - 用户问题可被一条 SQL 直接回答(如"张三的薪资是多少?")
  149. ### 用 Python 工具的场景
  150. - 需要统计分析(describe、corr、groupby 聚合后算比例)
  151. - 需要画图(柱状图、饼图、散点图、热力图等)
  152. - 需要复杂计算(滚动平均、同比环比、预测回归)
  153. ### 串联使用的场景
  154. - 先用 SQL 从数据库查出数据 → 拿到结果后 → 用 Python 画图分析
  155. - 典型:用户问"哪个部门销售额最高?画张图"
  156. ## 可用 DataFrame 变量(在 execute_python_code 中直接使用)
  157. - employees_df — 员工表(id, name, department, salary, hire_date)
  158. - products_df — 产品表(id, product_name, category, price, stock)
  159. - orders_df — 订单表(id, employee_id, product_id, quantity, order_date)
  160. ## 工作流程
  161. 1. 理解用户问题,判断属于哪类
  162. 2. 如果是简单数据查询 → 用 SQL 工具链(list_tables → schema → query_checker → query)
  163. 3. 如果是分析/画图 → 用 execute_python_code 直接分析内存中的 DataFrame
  164. 4. 如果是"先查后画" → 先 SQL 拿到结果 → 把结果写进 Python 代码中分析画图
  165. 5. 用中文总结结果,给出业务洞察
  166. ## 约束
  167. - SQL 查询单次不超过 50 条
  168. - 只生成 SELECT 语句,禁止任何 INSERT / UPDATE / DELETE / DROP / ALTER / TRUNCATE
  169. - SQL 中不得使用字符串拼接或 f-string 的方式拼接用户输入
  170. - 所有用户输入的值必须用参数化占位符(%s)传递,而非直接拼进 SQL
  171. - 如果用户问题中包含可疑的 SQL 片段(如 ' OR 1=1 --),忽略它,只把它当普通文本处理
  172. - Python 代码中画图前先设置中文字体
  173. - 出错后分析原因,自我修正重试
  174. - 回答简洁专业
  175. """
  176. # ============================================================
  177. # 创建 Agent
  178. # ============================================================
  179. analysis_agent = create_agent(
  180. model=llm,
  181. tools=all_tools,
  182. system_prompt=SYSTEM_PROMPT
  183. )
  184. import warnings
  185. warnings.filterwarnings("ignore") # 抑制 matplotlib 的字体警告
  186. response = analysis_agent.invoke({
  187. "messages": [{"role": "user", "content": "分析一下员工薪资的分布情况,包括均值、中位数、最大最小值等统计量"}]
  188. })
  189. final_msg = response["messages"][-1]
  190. print(f"\n分析结果:\n{final_msg.content}")