| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549 |
- import os
- import traceback
- from io import StringIO
- from contextlib import redirect_stdout
- import mysql.connector
- import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
- import seaborn as sns
- from dotenv import load_dotenv
- from langchain.tools import tool
- from langchain.agents import create_agent
- from langchain_openai import ChatOpenAI
- from langchain_community.utilities import SQLDatabase
- from langchain_community.agent_toolkits import SQLDatabaseToolkit
- from sqlalchemy import create_engine
- class DataAnalysisSystem:
- """
- 数据分析系统:集成 NL2SQL、数据可视化、意图路由的统一分析平台。
- 支持硅基流动/DeepSeek 等大模型 API。
- """
- # ============================================================
- # 类级常量:Prompt 模板
- # ============================================================
- SQL_AGENT_PROMPT = """你是一名专业的 SQL 数据分析师。
- ## 工作流程
- 1. 先用 sql_db_list_tables 查看数据库中有哪些表
- 2. 用 sql_db_schema 获取相关表的字段结构和类型
- 3. 生成 SQL 之前,用 sql_db_query_checker 检查语法
- 4. 确认无误后,用 sql_db_query 执行查询
- 5. 用中文总结查询结果,给出简洁的业务洞察
-
- ## 约束
- - 只使用数据库中实际存在的表和字段,不要凭空编造
- - 单次查询结果限制在 50 条以内
- - 如果查询出错,分析错误原因后重新生成 SQL
- - 回答要简洁专业,不要啰嗦
- """
- VISUALIZATION_PROMPT = """你是一名资深数据分析师,精通 Python、Pandas 和 Matplotlib 数据可视化。
- ## 可用数据
- 1. employees_df — 员工表(字段:id, name, department, salary, hire_date)
- 2. products_df — 产品表(字段:id, product_name, category, price, stock)
- 3. orders_df — 订单表(字段:id, employee_id, product_id, quantity, order_date)
-
- ## 工作流程
- 1. 理解用户的分析需求
- 2. 用 execute_python_code 工具编写并执行 Python 代码
- 3. 先做数据探索(head、describe、info),再做深入分析
- 4. 用中文解释分析结果,给出业务洞察
- ## 代码规范
- - 绑图前设置中文字体:plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']
- - 设置 plt.rcParams['axes.unicode_minus'] = False
- - 图表尺寸统一用 plt.figure(figsize=(10, 6))
- - 必须添加标题、坐标轴标签,让图表自解释
- - 用 print() 输出关键统计量,不要只画图不说话
- - 图表标题用英文(避免渲染问题),但用中文向用户解释结果
-
- ## 注意事项
- - 每次只执行一段完整的代码,不要拆成多段
- - 先探索数据结构,再做分析——不要上来就画图
- - 结果要有业务洞察,不只是"最大值是 XXX"
- """
- # ============================================================
- # 构造函数
- # ============================================================
- def __init__(self, env_file: str = ".env"):
- """
- 初始化数据分析系统。
- Args:
- env_file: 环境变量文件路径,默认当前目录的 .env
- """
- # 加载环境变量
- load_dotenv(dotenv_path=env_file)
- # 初始化 LLM
- self.api_key = os.getenv("API_KEY")
- self.model_name = os.getenv("MODEL_NAME")
- self.base_url = os.getenv("BASE_URL")
- self.llm = ChatOpenAI(
- model=self.model_name,
- api_key=self.api_key,
- base_url=self.base_url,
- temperature=0,
- )
- print("✅ 大语言模型初始化完成")
- # 构建数据库 URI
- self.db_uri = (
- f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
- f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
- )
- print("✅ 数据库 URI 构建完成")
- # 数据缓存
- self.employees_df = None
- self.products_df = None
- self.orders_df = None
- # Agent 缓存(延迟初始化)
- self._sql_agent = None
- self._visual_agent = None
- self._sandbox_globals = None
- self.together_agent = None
- # ============================================================
- # 公共方法:测试模型连接
- # ============================================================
- def test_llm(self, prompt: str = "什么是量子纠缠?") -> str:
- """测试 LLM 连接是否正常。"""
- response = self.llm.invoke(prompt)
- return response.content
- # ============================================================
- # 公共方法:初始化数据库(建表 + 插数据)
- # ============================================================
- def init_database(self) -> None:
- """
- 创建演示数据库:employees、products、orders 三张表,
- 并插入示例数据。
- """
- conn = mysql.connector.connect(
- host=os.getenv("DB_HOST"),
- port=int(os.getenv("DB_PORT")),
- user=os.getenv("DB_USER"),
- password=os.getenv("DB_PASSWORD"),
- database=os.getenv("DB_NAME"),
- )
- cursor = conn.cursor()
- # 建表
- cursor.execute("""
- CREATE TABLE IF NOT EXISTS employees (
- id INT PRIMARY KEY AUTO_INCREMENT,
- name VARCHAR(50) NOT NULL,
- department VARCHAR(50) NOT NULL,
- salary DECIMAL(10,2) NOT NULL,
- hire_date DATE
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
- """)
- cursor.execute("""
- CREATE TABLE IF NOT EXISTS products (
- id INT PRIMARY KEY AUTO_INCREMENT,
- product_name VARCHAR(100) NOT NULL,
- category VARCHAR(50) NOT NULL,
- price DECIMAL(10,2) NOT NULL,
- stock INT DEFAULT 0
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
- """)
- cursor.execute("""
- CREATE TABLE IF NOT EXISTS orders (
- id INT PRIMARY KEY AUTO_INCREMENT,
- employee_id INT NOT NULL,
- product_id INT NOT NULL,
- quantity INT NOT NULL,
- order_date DATE NOT NULL,
- FOREIGN KEY (employee_id) REFERENCES employees(id),
- FOREIGN KEY (product_id) REFERENCES products(id)
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
- """)
- # 插入数据
- employees_data = [
- (1, "张三", "技术部", 20000.00, "2023-01-15"),
- (2, "李四", "销售部", 11000.00, "2023-02-20"),
- (3, "王五", "技术部", 16000.00, "2022-11-10"),
- (4, "赵六", "人力资源", 5000.00, "2023-03-01"),
- (5, "钱七", "销售部", 17000.00, "2022-12-05"),
- ]
- products_data = [
- (1, "笔记本电脑", "电子产品", 6999.00, 500),
- (2, "机械键盘", "电子产品", 399.00, 1000),
- (3, "办公椅", "办公用品", 499.00, 300),
- (4, "显示器", "电子产品", 1200.00, 400),
- ]
- orders_data = [
- (1, 1, 1, 2, "2024-01-15"),
- (2, 2, 2, 15, "2024-01-16"),
- (3, 3, 1, 10, "2024-01-17"),
- (4, 5, 3, 6, "2024-01-18"),
- (5, 2, 4, 5, "2024-01-19"),
- ]
- cursor.executemany(
- "INSERT IGNORE INTO employees VALUES (%s,%s,%s,%s,%s)", employees_data
- )
- cursor.executemany(
- "INSERT IGNORE INTO products VALUES (%s,%s,%s,%s,%s)", products_data
- )
- cursor.executemany(
- "INSERT IGNORE INTO orders VALUES (%s,%s,%s,%s,%s)", orders_data
- )
- conn.commit()
- conn.close()
- print("✅ 数据库初始化完成")
- print(" - employees 表:5 条员工记录")
- print(" - products 表:4 条产品记录")
- print(" - orders 表:5 条订单记录")
- # ============================================================
- # 内部方法:获取 SQL Agent(延迟初始化)
- # ============================================================
- def _get_sql_agent(self):
- """延迟初始化 SQL Agent。"""
- if self._sql_agent is not None:
- return self._sql_agent
- db = SQLDatabase.from_uri(self.db_uri)
- print(f"数据库连接成功,可用表:{db.get_usable_table_names()}")
- toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
- tools = toolkit.get_tools()
- print(f"SQL 工具包已加载({len(tools)} 个工具):")
- for t in tools:
- print(f" - {t.name}")
- self._sql_agent = create_agent(
- model=self.llm,
- tools=tools,
- system_prompt=self.SQL_AGENT_PROMPT,
- )
- print("NL2SQL Agent 创建完成")
- return self._sql_agent
- # ============================================================
- # 内部方法:加载数据到 DataFrame(用于可视化)
- # ============================================================
- def _load_dataframes(self):
- """从数据库加载数据到 Pandas DataFrame。"""
- if self.employees_df is not None:
- return # 已加载,跳过
- # 配置 matplotlib 中文显示
- plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"]
- plt.rcParams["axes.unicode_minus"] = False
- engine = create_engine(self.db_uri)
- self.employees_df = pd.read_sql("SELECT * FROM employees", engine)
- self.products_df = pd.read_sql("SELECT * FROM products", engine)
- self.orders_df = pd.read_sql("SELECT * FROM orders", engine)
- print("✅ 数据加载完成")
- print(f" 员工表:{len(self.employees_df)} 行")
- print(f" 产品表:{len(self.products_df)} 行")
- print(f" 订单表:{len(self.orders_df)} 行")
- # ============================================================
- # 内部方法:获取可视化 Agent(延迟初始化)
- # ============================================================
- def _get_visual_agent(self):
- """延迟初始化可视化 Agent。"""
- if self._visual_agent is not None:
- return self._visual_agent
- self._load_dataframes()
- # 构建沙箱全局变量
- self._sandbox_globals = {
- "employees_df": self.employees_df,
- "products_df": self.products_df,
- "orders_df": self.orders_df,
- "pd": pd,
- "plt": plt,
- "sns": sns,
- "np": np,
- }
- # 创建工具
- @tool
- def execute_python_code(code: str) -> str:
- """
- 执行 Python 代码进行数据分析和可视化。
- 可用变量:
- - employees_df, products_df, orders_df
- - pd, plt, sns, np
- """
- exec_globals = dict(self._sandbox_globals)
- exec_locals = {}
- output_buffer = StringIO()
- try:
- with redirect_stdout(output_buffer):
- exec(code, exec_globals, exec_locals)
- result = output_buffer.getvalue()
- if not result.strip():
- result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
- return f"执行成功:\n{result}"
- except Exception as e:
- error_detail = traceback.format_exc()
- return f"❌ 执行出错:{e}\n\n{error_detail}"
- viz_tools = [execute_python_code]
- self._visual_agent = create_agent(
- model=self.llm,
- tools=viz_tools,
- system_prompt=self.VISUALIZATION_PROMPT,
- )
- print("✅ 数据可视化 Agent 创建完成")
- return self._visual_agent
- # ============================================================
- # 公共方法:运行 SQL 查询
- # ============================================================
- def query(self, user_query: str) -> str:
- """
- 使用 SQL Agent 执行自然语言到 SQL 的查询。
- Args:
- user_query: 用户的自然语言问题
- Returns:
- Agent 的分析结果文本
- """
- agent = self._get_sql_agent()
- result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
- return result["messages"][-1].content
- # ============================================================
- # 公共方法:运行可视化分析
- # ============================================================
- def visualize(self, user_query: str) -> str:
- """
- 使用可视化 Agent 进行数据分析和图表生成。
- Args:
- user_query: 用户的分析需求描述
- Returns:
- Agent 的分析结果文本
- """
- agent = self._get_visual_agent()
- result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
- return result["messages"][-1].content
- # ============================================================
- # 公共方法:统一入口(自动意图路由)
- # ============================================================
- def analyze(self, user_query: str) -> str:
- """
- 统一分析入口:自动判断用户意图,路由到对应 Agent。
- 意图分类:
- - query:纯数据库查询
- - visualize:纯数据可视化/统计分析
- - both:先查询再可视化
- Args:
- user_query: 用户的自然语言问题
- Returns:
- 分析结果文本
- """
- classification_prompt = f"""判断以下用户问题属于哪种类型:
- - "query":需要查询数据库获取数据
- - "visualize":需要画图或做统计分析
- - "both":需要先查数据,再画图分析
- 用户问题:{user_query}
- 只回复一个词:query / visualize / both"""
- intent = self.llm.invoke(classification_prompt).content.strip().lower()
- if intent == "query":
- return self.query(user_query)
- elif intent == "visualize":
- return self.visualize(user_query)
- else: # both
- data_result = self.query(user_query)
- viz_input = f"基于以下数据进行可视化分析:\n{data_result}\n\n原始问题:{user_query}"
- return self.visualize(viz_input)
- # ============================================================
- # 内部方法:获取或创建统一 Agent
- # ============================================================
- def get_together_agent(self):
- TOGETHER_SYSTEM_PROMPT = """你是一名全能数据分析师,精通 SQL 查询和 Python 数据分析可视化。
- ## 可用工具
- 1. **SQL 工具组**(数据库操作):
- - `sql_db_list_tables` — 查看数据库中有哪些表
- - `sql_db_schema` — 获取指定表的字段结构
- - `sql_db_query_checker` — 检查 SQL 语法是否正确
- - `sql_db_query` — 执行 SQL 查询并返回结果
- 2. **Python 代码执行工具**(数据分析与可视化):
- - `execute_python_code` — 在沙箱中执行 Python 代码
- ## 可用数据(Python 沙箱中已预加载)
- - `employees_df` — 员工表(id, name, department, salary, hire_date)
- - `products_df` — 产品表(id, product_name, category, price, stock)
- - `orders_df` — 订单表(id, employee_id, product_id, quantity, order_date)
- - `pd` (pandas), `plt` (matplotlib), `sns` (seaborn), `np` (numpy)
- ## 工作流程
- 1. **理解需求**:分析用户问题的核心诉求
- 2. **获取元数据**:如需查数据库,先用 `sql_db_list_tables` 和 `sql_db_schema` 了解表结构
- 3. **选择工具**:
- - 纯数据查询 → 用 SQL 工具
- - 统计分析、画图 → 用 `execute_python_code`
- - 复杂分析 → 先用 SQL 查数据,再用 Python 做深度分析
- 4. **执行与验证**:SQL 先用 `sql_db_query_checker` 检查语法;Python 代码确保正确
- 5. **总结输出**:用中文给出简洁的业务洞察
- ## 约束
- - 只使用数据库中实际存在的表和字段
- - SQL 单次查询结果限制在 50 条以内
- - Python 代码绑图前设置中文字体:`plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']`
- - 设置 `plt.rcParams['axes.unicode_minus'] = False`
- - 图表尺寸用 `plt.figure(figsize=(10, 6))`
- - 必须添加标题、坐标轴标签
- - 用 `print()` 输出关键统计量
- - 回答简洁专业,不要啰嗦
- """
- """延迟初始化统一 Agent,合并 SQL 工具 + Python 执行工具。"""
- if self.together_agent is not None:
- return self.together_agent
- # ========== 1. SQL 工具组 ==========
- db = SQLDatabase.from_uri(self.db_uri)
- print(f"数据库连接成功,可用表:{db.get_usable_table_names()}")
- sql_toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
- sql_tools = sql_toolkit.get_tools()
- print(f"SQL 工具加载完成({len(sql_tools)} 个):")
- for t in sql_tools:
- print(f" - {t.name}")
- # ========== 2. Python 代码执行工具 ==========
- self._load_dataframes()
- self._sandbox_globals = {
- "employees_df": self.employees_df,
- "products_df": self.products_df,
- "orders_df": self.orders_df,
- "pd": pd,
- "plt": plt,
- "sns": sns,
- "np": np,
- }
- @tool
- def execute_python_code(code: str) -> str:
- """
- 执行 Python 代码进行数据分析和可视化。
- 可用变量:employees_df, products_df, orders_df, pd, plt, sns, np
- """
- exec_globals = dict(self._sandbox_globals)
- exec_locals = {}
- output_buffer = StringIO()
- try:
- with redirect_stdout(output_buffer):
- exec(code, exec_globals, exec_locals)
- result = output_buffer.getvalue()
- if not result.strip():
- result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
- return f"执行成功:\n{result}"
- except Exception as e:
- error_detail = traceback.format_exc()
- return f"❌ 执行出错:{e}\n\n{error_detail}"
- # ========== 3. 合并所有工具 ==========
- all_tools = sql_tools + [execute_python_code]
- print(f"\n✅ 工具合并完成,共 {len(all_tools)} 个工具:")
- for t in all_tools:
- print(f" - {t.name}")
- # ========== 4. 创建统一 Agent ==========
- self.together_agent = create_agent(
- model=self.llm,
- tools=all_tools,
- system_prompt=TOGETHER_SYSTEM_PROMPT,
- )
- print("\n🤖 统一分析 Agent 创建完成!")
- return self.together_agent
- # ============================================================
- # 统⼀ Agent 模式(作业)
- # ============================================================
- def together_analyze(self, user_query: str) -> str:
- """
- 统一分析入口:LLM 自主决定调用 SQL 工具或 Python 工具。
- Args:
- user_query: 用户的自然语言问题
- Returns:
- Agent 的分析结果
- """
- agent = self.get_together_agent()
- result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
- return result["messages"][-1].content
- # ============================================================
- # 使用示例
- # ============================================================
- if __name__ == "__main__":
- # 1. 创建系统实例
- system = DataAnalysisSystem()
- # 2. 测试 LLM 连接
- print("\n--- LLM 测试 ---")
- print(system.test_llm("你好,请用一句话介绍自己"))
- # 3. 初始化数据库(只需执行一次)
- # print("\n--- 初始化数据库 ---")
- # system.init_database()
- # 4. 运行分析
- # print("\n--- SQL 查询 ---")
- # result = system.query("技术部的平均工资是多少?")
- # print(result)
- #
- # print("\n--- 可视化分析 ---")
- # result = system.visualize("统计各部门人数并画柱状图")
- # print(result)
- # print("\n--- 自动路由 ---")
- # result = system.analyze("销售部谁的总销售额最高?")
- # print(result)
- print("\n--- 统一分析:LLM 自动选择工具 ---")
- result = system.together_analyze("统计各部门人数并画柱状图?")
- print(result)
|