01_data_analysis.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. import os
  2. import traceback
  3. from io import StringIO
  4. from contextlib import redirect_stdout
  5. import mysql.connector
  6. import numpy as np
  7. import pandas as pd
  8. import matplotlib.pyplot as plt
  9. import seaborn as sns
  10. from dotenv import load_dotenv
  11. from langchain.tools import tool
  12. from langchain.agents import create_agent
  13. from langchain_openai import ChatOpenAI
  14. from langchain_community.utilities import SQLDatabase
  15. from langchain_community.agent_toolkits import SQLDatabaseToolkit
  16. from sqlalchemy import create_engine
  17. class DataAnalysisSystem:
  18. """
  19. 数据分析系统:集成 NL2SQL、数据可视化、意图路由的统一分析平台。
  20. 支持硅基流动/DeepSeek 等大模型 API。
  21. """
  22. # ============================================================
  23. # 类级常量:Prompt 模板
  24. # ============================================================
  25. SQL_AGENT_PROMPT = """你是一名专业的 SQL 数据分析师。
  26. ## 工作流程
  27. 1. 先用 sql_db_list_tables 查看数据库中有哪些表
  28. 2. 用 sql_db_schema 获取相关表的字段结构和类型
  29. 3. 生成 SQL 之前,用 sql_db_query_checker 检查语法
  30. 4. 确认无误后,用 sql_db_query 执行查询
  31. 5. 用中文总结查询结果,给出简洁的业务洞察
  32. ## 约束
  33. - 只使用数据库中实际存在的表和字段,不要凭空编造
  34. - 单次查询结果限制在 50 条以内
  35. - 如果查询出错,分析错误原因后重新生成 SQL
  36. - 回答要简洁专业,不要啰嗦
  37. """
  38. VISUALIZATION_PROMPT = """你是一名资深数据分析师,精通 Python、Pandas 和 Matplotlib 数据可视化。
  39. ## 可用数据
  40. 1. employees_df — 员工表(字段:id, name, department, salary, hire_date)
  41. 2. products_df — 产品表(字段:id, product_name, category, price, stock)
  42. 3. orders_df — 订单表(字段:id, employee_id, product_id, quantity, order_date)
  43. ## 工作流程
  44. 1. 理解用户的分析需求
  45. 2. 用 execute_python_code 工具编写并执行 Python 代码
  46. 3. 先做数据探索(head、describe、info),再做深入分析
  47. 4. 用中文解释分析结果,给出业务洞察
  48. ## 代码规范
  49. - 绑图前设置中文字体:plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']
  50. - 设置 plt.rcParams['axes.unicode_minus'] = False
  51. - 图表尺寸统一用 plt.figure(figsize=(10, 6))
  52. - 必须添加标题、坐标轴标签,让图表自解释
  53. - 用 print() 输出关键统计量,不要只画图不说话
  54. - 图表标题用英文(避免渲染问题),但用中文向用户解释结果
  55. ## 注意事项
  56. - 每次只执行一段完整的代码,不要拆成多段
  57. - 先探索数据结构,再做分析——不要上来就画图
  58. - 结果要有业务洞察,不只是"最大值是 XXX"
  59. """
  60. # ============================================================
  61. # 构造函数
  62. # ============================================================
  63. def __init__(self, env_file: str = ".env"):
  64. """
  65. 初始化数据分析系统。
  66. Args:
  67. env_file: 环境变量文件路径,默认当前目录的 .env
  68. """
  69. # 加载环境变量
  70. load_dotenv(dotenv_path=env_file)
  71. # 初始化 LLM
  72. self.api_key = os.getenv("API_KEY")
  73. self.model_name = os.getenv("MODEL_NAME")
  74. self.base_url = os.getenv("BASE_URL")
  75. self.llm = ChatOpenAI(
  76. model=self.model_name,
  77. api_key=self.api_key,
  78. base_url=self.base_url,
  79. temperature=0,
  80. )
  81. print("✅ 大语言模型初始化完成")
  82. # 构建数据库 URI
  83. self.db_uri = (
  84. f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
  85. f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
  86. )
  87. print("✅ 数据库 URI 构建完成")
  88. # 数据缓存
  89. self.employees_df = None
  90. self.products_df = None
  91. self.orders_df = None
  92. # Agent 缓存(延迟初始化)
  93. self._sql_agent = None
  94. self._visual_agent = None
  95. self._sandbox_globals = None
  96. self.together_agent = None
  97. # ============================================================
  98. # 公共方法:测试模型连接
  99. # ============================================================
  100. def test_llm(self, prompt: str = "什么是量子纠缠?") -> str:
  101. """测试 LLM 连接是否正常。"""
  102. response = self.llm.invoke(prompt)
  103. return response.content
  104. # ============================================================
  105. # 公共方法:初始化数据库(建表 + 插数据)
  106. # ============================================================
  107. def init_database(self) -> None:
  108. """
  109. 创建演示数据库:employees、products、orders 三张表,
  110. 并插入示例数据。
  111. """
  112. conn = mysql.connector.connect(
  113. host=os.getenv("DB_HOST"),
  114. port=int(os.getenv("DB_PORT")),
  115. user=os.getenv("DB_USER"),
  116. password=os.getenv("DB_PASSWORD"),
  117. database=os.getenv("DB_NAME"),
  118. )
  119. cursor = conn.cursor()
  120. # 建表
  121. cursor.execute("""
  122. CREATE TABLE IF NOT EXISTS employees (
  123. id INT PRIMARY KEY AUTO_INCREMENT,
  124. name VARCHAR(50) NOT NULL,
  125. department VARCHAR(50) NOT NULL,
  126. salary DECIMAL(10,2) NOT NULL,
  127. hire_date DATE
  128. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
  129. """)
  130. cursor.execute("""
  131. CREATE TABLE IF NOT EXISTS products (
  132. id INT PRIMARY KEY AUTO_INCREMENT,
  133. product_name VARCHAR(100) NOT NULL,
  134. category VARCHAR(50) NOT NULL,
  135. price DECIMAL(10,2) NOT NULL,
  136. stock INT DEFAULT 0
  137. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
  138. """)
  139. cursor.execute("""
  140. CREATE TABLE IF NOT EXISTS orders (
  141. id INT PRIMARY KEY AUTO_INCREMENT,
  142. employee_id INT NOT NULL,
  143. product_id INT NOT NULL,
  144. quantity INT NOT NULL,
  145. order_date DATE NOT NULL,
  146. FOREIGN KEY (employee_id) REFERENCES employees(id),
  147. FOREIGN KEY (product_id) REFERENCES products(id)
  148. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
  149. """)
  150. # 插入数据
  151. employees_data = [
  152. (1, "张三", "技术部", 20000.00, "2023-01-15"),
  153. (2, "李四", "销售部", 11000.00, "2023-02-20"),
  154. (3, "王五", "技术部", 16000.00, "2022-11-10"),
  155. (4, "赵六", "人力资源", 5000.00, "2023-03-01"),
  156. (5, "钱七", "销售部", 17000.00, "2022-12-05"),
  157. ]
  158. products_data = [
  159. (1, "笔记本电脑", "电子产品", 6999.00, 500),
  160. (2, "机械键盘", "电子产品", 399.00, 1000),
  161. (3, "办公椅", "办公用品", 499.00, 300),
  162. (4, "显示器", "电子产品", 1200.00, 400),
  163. ]
  164. orders_data = [
  165. (1, 1, 1, 2, "2024-01-15"),
  166. (2, 2, 2, 15, "2024-01-16"),
  167. (3, 3, 1, 10, "2024-01-17"),
  168. (4, 5, 3, 6, "2024-01-18"),
  169. (5, 2, 4, 5, "2024-01-19"),
  170. ]
  171. cursor.executemany(
  172. "INSERT IGNORE INTO employees VALUES (%s,%s,%s,%s,%s)", employees_data
  173. )
  174. cursor.executemany(
  175. "INSERT IGNORE INTO products VALUES (%s,%s,%s,%s,%s)", products_data
  176. )
  177. cursor.executemany(
  178. "INSERT IGNORE INTO orders VALUES (%s,%s,%s,%s,%s)", orders_data
  179. )
  180. conn.commit()
  181. conn.close()
  182. print("✅ 数据库初始化完成")
  183. print(" - employees 表:5 条员工记录")
  184. print(" - products 表:4 条产品记录")
  185. print(" - orders 表:5 条订单记录")
  186. # ============================================================
  187. # 内部方法:获取 SQL Agent(延迟初始化)
  188. # ============================================================
  189. def _get_sql_agent(self):
  190. """延迟初始化 SQL Agent。"""
  191. if self._sql_agent is not None:
  192. return self._sql_agent
  193. db = SQLDatabase.from_uri(self.db_uri)
  194. print(f"数据库连接成功,可用表:{db.get_usable_table_names()}")
  195. toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
  196. tools = toolkit.get_tools()
  197. print(f"SQL 工具包已加载({len(tools)} 个工具):")
  198. for t in tools:
  199. print(f" - {t.name}")
  200. self._sql_agent = create_agent(
  201. model=self.llm,
  202. tools=tools,
  203. system_prompt=self.SQL_AGENT_PROMPT,
  204. )
  205. print("NL2SQL Agent 创建完成")
  206. return self._sql_agent
  207. # ============================================================
  208. # 内部方法:加载数据到 DataFrame(用于可视化)
  209. # ============================================================
  210. def _load_dataframes(self):
  211. """从数据库加载数据到 Pandas DataFrame。"""
  212. if self.employees_df is not None:
  213. return # 已加载,跳过
  214. # 配置 matplotlib 中文显示
  215. plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"]
  216. plt.rcParams["axes.unicode_minus"] = False
  217. engine = create_engine(self.db_uri)
  218. self.employees_df = pd.read_sql("SELECT * FROM employees", engine)
  219. self.products_df = pd.read_sql("SELECT * FROM products", engine)
  220. self.orders_df = pd.read_sql("SELECT * FROM orders", engine)
  221. print("✅ 数据加载完成")
  222. print(f" 员工表:{len(self.employees_df)} 行")
  223. print(f" 产品表:{len(self.products_df)} 行")
  224. print(f" 订单表:{len(self.orders_df)} 行")
  225. # ============================================================
  226. # 内部方法:获取可视化 Agent(延迟初始化)
  227. # ============================================================
  228. def _get_visual_agent(self):
  229. """延迟初始化可视化 Agent。"""
  230. if self._visual_agent is not None:
  231. return self._visual_agent
  232. self._load_dataframes()
  233. # 构建沙箱全局变量
  234. self._sandbox_globals = {
  235. "employees_df": self.employees_df,
  236. "products_df": self.products_df,
  237. "orders_df": self.orders_df,
  238. "pd": pd,
  239. "plt": plt,
  240. "sns": sns,
  241. "np": np,
  242. }
  243. # 创建工具
  244. @tool
  245. def execute_python_code(code: str) -> str:
  246. """
  247. 执行 Python 代码进行数据分析和可视化。
  248. 可用变量:
  249. - employees_df, products_df, orders_df
  250. - pd, plt, sns, np
  251. """
  252. exec_globals = dict(self._sandbox_globals)
  253. exec_locals = {}
  254. output_buffer = StringIO()
  255. try:
  256. with redirect_stdout(output_buffer):
  257. exec(code, exec_globals, exec_locals)
  258. result = output_buffer.getvalue()
  259. if not result.strip():
  260. result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
  261. return f"执行成功:\n{result}"
  262. except Exception as e:
  263. error_detail = traceback.format_exc()
  264. return f"❌ 执行出错:{e}\n\n{error_detail}"
  265. viz_tools = [execute_python_code]
  266. self._visual_agent = create_agent(
  267. model=self.llm,
  268. tools=viz_tools,
  269. system_prompt=self.VISUALIZATION_PROMPT,
  270. )
  271. print("✅ 数据可视化 Agent 创建完成")
  272. return self._visual_agent
  273. # ============================================================
  274. # 公共方法:运行 SQL 查询
  275. # ============================================================
  276. def query(self, user_query: str) -> str:
  277. """
  278. 使用 SQL Agent 执行自然语言到 SQL 的查询。
  279. Args:
  280. user_query: 用户的自然语言问题
  281. Returns:
  282. Agent 的分析结果文本
  283. """
  284. agent = self._get_sql_agent()
  285. result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
  286. return result["messages"][-1].content
  287. # ============================================================
  288. # 公共方法:运行可视化分析
  289. # ============================================================
  290. def visualize(self, user_query: str) -> str:
  291. """
  292. 使用可视化 Agent 进行数据分析和图表生成。
  293. Args:
  294. user_query: 用户的分析需求描述
  295. Returns:
  296. Agent 的分析结果文本
  297. """
  298. agent = self._get_visual_agent()
  299. result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
  300. return result["messages"][-1].content
  301. # ============================================================
  302. # 公共方法:统一入口(自动意图路由)
  303. # ============================================================
  304. def analyze(self, user_query: str) -> str:
  305. """
  306. 统一分析入口:自动判断用户意图,路由到对应 Agent。
  307. 意图分类:
  308. - query:纯数据库查询
  309. - visualize:纯数据可视化/统计分析
  310. - both:先查询再可视化
  311. Args:
  312. user_query: 用户的自然语言问题
  313. Returns:
  314. 分析结果文本
  315. """
  316. classification_prompt = f"""判断以下用户问题属于哪种类型:
  317. - "query":需要查询数据库获取数据
  318. - "visualize":需要画图或做统计分析
  319. - "both":需要先查数据,再画图分析
  320. 用户问题:{user_query}
  321. 只回复一个词:query / visualize / both"""
  322. intent = self.llm.invoke(classification_prompt).content.strip().lower()
  323. if intent == "query":
  324. return self.query(user_query)
  325. elif intent == "visualize":
  326. return self.visualize(user_query)
  327. else: # both
  328. data_result = self.query(user_query)
  329. viz_input = f"基于以下数据进行可视化分析:\n{data_result}\n\n原始问题:{user_query}"
  330. return self.visualize(viz_input)
  331. # ============================================================
  332. # 内部方法:获取或创建统一 Agent
  333. # ============================================================
  334. def get_together_agent(self):
  335. TOGETHER_SYSTEM_PROMPT = """你是一名全能数据分析师,精通 SQL 查询和 Python 数据分析可视化。
  336. ## 可用工具
  337. 1. **SQL 工具组**(数据库操作):
  338. - `sql_db_list_tables` — 查看数据库中有哪些表
  339. - `sql_db_schema` — 获取指定表的字段结构
  340. - `sql_db_query_checker` — 检查 SQL 语法是否正确
  341. - `sql_db_query` — 执行 SQL 查询并返回结果
  342. 2. **Python 代码执行工具**(数据分析与可视化):
  343. - `execute_python_code` — 在沙箱中执行 Python 代码
  344. ## 可用数据(Python 沙箱中已预加载)
  345. - `employees_df` — 员工表(id, name, department, salary, hire_date)
  346. - `products_df` — 产品表(id, product_name, category, price, stock)
  347. - `orders_df` — 订单表(id, employee_id, product_id, quantity, order_date)
  348. - `pd` (pandas), `plt` (matplotlib), `sns` (seaborn), `np` (numpy)
  349. ## 工作流程
  350. 1. **理解需求**:分析用户问题的核心诉求
  351. 2. **获取元数据**:如需查数据库,先用 `sql_db_list_tables` 和 `sql_db_schema` 了解表结构
  352. 3. **选择工具**:
  353. - 纯数据查询 → 用 SQL 工具
  354. - 统计分析、画图 → 用 `execute_python_code`
  355. - 复杂分析 → 先用 SQL 查数据,再用 Python 做深度分析
  356. 4. **执行与验证**:SQL 先用 `sql_db_query_checker` 检查语法;Python 代码确保正确
  357. 5. **总结输出**:用中文给出简洁的业务洞察
  358. ## 约束
  359. - 只使用数据库中实际存在的表和字段
  360. - SQL 单次查询结果限制在 50 条以内
  361. - Python 代码绑图前设置中文字体:`plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']`
  362. - 设置 `plt.rcParams['axes.unicode_minus'] = False`
  363. - 图表尺寸用 `plt.figure(figsize=(10, 6))`
  364. - 必须添加标题、坐标轴标签
  365. - 用 `print()` 输出关键统计量
  366. - 回答简洁专业,不要啰嗦
  367. """
  368. """延迟初始化统一 Agent,合并 SQL 工具 + Python 执行工具。"""
  369. if self.together_agent is not None:
  370. return self.together_agent
  371. # ========== 1. SQL 工具组 ==========
  372. db = SQLDatabase.from_uri(self.db_uri)
  373. print(f"数据库连接成功,可用表:{db.get_usable_table_names()}")
  374. sql_toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
  375. sql_tools = sql_toolkit.get_tools()
  376. print(f"SQL 工具加载完成({len(sql_tools)} 个):")
  377. for t in sql_tools:
  378. print(f" - {t.name}")
  379. # ========== 2. Python 代码执行工具 ==========
  380. self._load_dataframes()
  381. self._sandbox_globals = {
  382. "employees_df": self.employees_df,
  383. "products_df": self.products_df,
  384. "orders_df": self.orders_df,
  385. "pd": pd,
  386. "plt": plt,
  387. "sns": sns,
  388. "np": np,
  389. }
  390. @tool
  391. def execute_python_code(code: str) -> str:
  392. """
  393. 执行 Python 代码进行数据分析和可视化。
  394. 可用变量:employees_df, products_df, orders_df, pd, plt, sns, np
  395. """
  396. exec_globals = dict(self._sandbox_globals)
  397. exec_locals = {}
  398. output_buffer = StringIO()
  399. try:
  400. with redirect_stdout(output_buffer):
  401. exec(code, exec_globals, exec_locals)
  402. result = output_buffer.getvalue()
  403. if not result.strip():
  404. result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
  405. return f"执行成功:\n{result}"
  406. except Exception as e:
  407. error_detail = traceback.format_exc()
  408. return f"❌ 执行出错:{e}\n\n{error_detail}"
  409. # ========== 3. 合并所有工具 ==========
  410. all_tools = sql_tools + [execute_python_code]
  411. print(f"\n✅ 工具合并完成,共 {len(all_tools)} 个工具:")
  412. for t in all_tools:
  413. print(f" - {t.name}")
  414. # ========== 4. 创建统一 Agent ==========
  415. self.together_agent = create_agent(
  416. model=self.llm,
  417. tools=all_tools,
  418. system_prompt=TOGETHER_SYSTEM_PROMPT,
  419. )
  420. print("\n🤖 统一分析 Agent 创建完成!")
  421. return self.together_agent
  422. # ============================================================
  423. # 统⼀ Agent 模式(作业)
  424. # ============================================================
  425. def together_analyze(self, user_query: str) -> str:
  426. """
  427. 统一分析入口:LLM 自主决定调用 SQL 工具或 Python 工具。
  428. Args:
  429. user_query: 用户的自然语言问题
  430. Returns:
  431. Agent 的分析结果
  432. """
  433. agent = self.get_together_agent()
  434. result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
  435. return result["messages"][-1].content
  436. # ============================================================
  437. # 使用示例
  438. # ============================================================
  439. if __name__ == "__main__":
  440. # 1. 创建系统实例
  441. system = DataAnalysisSystem()
  442. # 2. 测试 LLM 连接
  443. print("\n--- LLM 测试 ---")
  444. print(system.test_llm("你好,请用一句话介绍自己"))
  445. # 3. 初始化数据库(只需执行一次)
  446. # print("\n--- 初始化数据库 ---")
  447. # system.init_database()
  448. # 4. 运行分析
  449. # print("\n--- SQL 查询 ---")
  450. # result = system.query("技术部的平均工资是多少?")
  451. # print(result)
  452. #
  453. # print("\n--- 可视化分析 ---")
  454. # result = system.visualize("统计各部门人数并画柱状图")
  455. # print(result)
  456. # print("\n--- 自动路由 ---")
  457. # result = system.analyze("销售部谁的总销售额最高?")
  458. # print(result)
  459. print("\n--- 统一分析:LLM 自动选择工具 ---")
  460. result = system.together_analyze("统计各部门人数并画柱状图?")
  461. print(result)