import os import traceback from io import StringIO from contextlib import redirect_stdout import mysql.connector import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from dotenv import load_dotenv from langchain.tools import tool from langchain.agents import create_agent from langchain_openai import ChatOpenAI from langchain_community.utilities import SQLDatabase from langchain_community.agent_toolkits import SQLDatabaseToolkit from sqlalchemy import create_engine class DataAnalysisSystem: """ 数据分析系统:集成 NL2SQL、数据可视化、意图路由的统一分析平台。 支持硅基流动/DeepSeek 等大模型 API。 """ # ============================================================ # 类级常量:Prompt 模板 # ============================================================ SQL_AGENT_PROMPT = """你是一名专业的 SQL 数据分析师。 ## 工作流程 1. 先用 sql_db_list_tables 查看数据库中有哪些表 2. 用 sql_db_schema 获取相关表的字段结构和类型 3. 生成 SQL 之前,用 sql_db_query_checker 检查语法 4. 确认无误后,用 sql_db_query 执行查询 5. 用中文总结查询结果,给出简洁的业务洞察 ## 约束 - 只使用数据库中实际存在的表和字段,不要凭空编造 - 单次查询结果限制在 50 条以内 - 如果查询出错,分析错误原因后重新生成 SQL - 回答要简洁专业,不要啰嗦 """ VISUALIZATION_PROMPT = """你是一名资深数据分析师,精通 Python、Pandas 和 Matplotlib 数据可视化。 ## 可用数据 1. employees_df — 员工表(字段:id, name, department, salary, hire_date) 2. products_df — 产品表(字段:id, product_name, category, price, stock) 3. orders_df — 订单表(字段:id, employee_id, product_id, quantity, order_date) ## 工作流程 1. 理解用户的分析需求 2. 用 execute_python_code 工具编写并执行 Python 代码 3. 先做数据探索(head、describe、info),再做深入分析 4. 用中文解释分析结果,给出业务洞察 ## 代码规范 - 绑图前设置中文字体:plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans'] - 设置 plt.rcParams['axes.unicode_minus'] = False - 图表尺寸统一用 plt.figure(figsize=(10, 6)) - 必须添加标题、坐标轴标签,让图表自解释 - 用 print() 输出关键统计量,不要只画图不说话 - 图表标题用英文(避免渲染问题),但用中文向用户解释结果 ## 注意事项 - 每次只执行一段完整的代码,不要拆成多段 - 先探索数据结构,再做分析——不要上来就画图 - 结果要有业务洞察,不只是"最大值是 XXX" """ # ============================================================ # 构造函数 # ============================================================ def __init__(self, env_file: str = ".env"): """ 初始化数据分析系统。 Args: env_file: 环境变量文件路径,默认当前目录的 .env """ # 加载环境变量 load_dotenv(dotenv_path=env_file) # 初始化 LLM self.api_key = os.getenv("API_KEY") self.model_name = os.getenv("MODEL_NAME") self.base_url = os.getenv("BASE_URL") self.llm = ChatOpenAI( model=self.model_name, api_key=self.api_key, base_url=self.base_url, temperature=0, ) print("✅ 大语言模型初始化完成") # 构建数据库 URI self.db_uri = ( f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}" f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}" ) print("✅ 数据库 URI 构建完成") # 数据缓存 self.employees_df = None self.products_df = None self.orders_df = None # Agent 缓存(延迟初始化) self._sql_agent = None self._visual_agent = None self._sandbox_globals = None self.together_agent = None # ============================================================ # 公共方法:测试模型连接 # ============================================================ def test_llm(self, prompt: str = "什么是量子纠缠?") -> str: """测试 LLM 连接是否正常。""" response = self.llm.invoke(prompt) return response.content # ============================================================ # 公共方法:初始化数据库(建表 + 插数据) # ============================================================ def init_database(self) -> None: """ 创建演示数据库:employees、products、orders 三张表, 并插入示例数据。 """ conn = mysql.connector.connect( host=os.getenv("DB_HOST"), port=int(os.getenv("DB_PORT")), user=os.getenv("DB_USER"), password=os.getenv("DB_PASSWORD"), database=os.getenv("DB_NAME"), ) cursor = conn.cursor() # 建表 cursor.execute(""" CREATE TABLE IF NOT EXISTS employees ( id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(50) NOT NULL, department VARCHAR(50) NOT NULL, salary DECIMAL(10,2) NOT NULL, hire_date DATE ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; """) cursor.execute(""" CREATE TABLE IF NOT EXISTS products ( id INT PRIMARY KEY AUTO_INCREMENT, product_name VARCHAR(100) NOT NULL, category VARCHAR(50) NOT NULL, price DECIMAL(10,2) NOT NULL, stock INT DEFAULT 0 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; """) cursor.execute(""" CREATE TABLE IF NOT EXISTS orders ( id INT PRIMARY KEY AUTO_INCREMENT, employee_id INT NOT NULL, product_id INT NOT NULL, quantity INT NOT NULL, order_date DATE NOT NULL, FOREIGN KEY (employee_id) REFERENCES employees(id), FOREIGN KEY (product_id) REFERENCES products(id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; """) # 插入数据 employees_data = [ (1, "张三", "技术部", 20000.00, "2023-01-15"), (2, "李四", "销售部", 11000.00, "2023-02-20"), (3, "王五", "技术部", 16000.00, "2022-11-10"), (4, "赵六", "人力资源", 5000.00, "2023-03-01"), (5, "钱七", "销售部", 17000.00, "2022-12-05"), ] products_data = [ (1, "笔记本电脑", "电子产品", 6999.00, 500), (2, "机械键盘", "电子产品", 399.00, 1000), (3, "办公椅", "办公用品", 499.00, 300), (4, "显示器", "电子产品", 1200.00, 400), ] orders_data = [ (1, 1, 1, 2, "2024-01-15"), (2, 2, 2, 15, "2024-01-16"), (3, 3, 1, 10, "2024-01-17"), (4, 5, 3, 6, "2024-01-18"), (5, 2, 4, 5, "2024-01-19"), ] cursor.executemany( "INSERT IGNORE INTO employees VALUES (%s,%s,%s,%s,%s)", employees_data ) cursor.executemany( "INSERT IGNORE INTO products VALUES (%s,%s,%s,%s,%s)", products_data ) cursor.executemany( "INSERT IGNORE INTO orders VALUES (%s,%s,%s,%s,%s)", orders_data ) conn.commit() conn.close() print("✅ 数据库初始化完成") print(" - employees 表:5 条员工记录") print(" - products 表:4 条产品记录") print(" - orders 表:5 条订单记录") # ============================================================ # 内部方法:获取 SQL Agent(延迟初始化) # ============================================================ def _get_sql_agent(self): """延迟初始化 SQL Agent。""" if self._sql_agent is not None: return self._sql_agent db = SQLDatabase.from_uri(self.db_uri) print(f"数据库连接成功,可用表:{db.get_usable_table_names()}") toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) tools = toolkit.get_tools() print(f"SQL 工具包已加载({len(tools)} 个工具):") for t in tools: print(f" - {t.name}") self._sql_agent = create_agent( model=self.llm, tools=tools, system_prompt=self.SQL_AGENT_PROMPT, ) print("NL2SQL Agent 创建完成") return self._sql_agent # ============================================================ # 内部方法:加载数据到 DataFrame(用于可视化) # ============================================================ def _load_dataframes(self): """从数据库加载数据到 Pandas DataFrame。""" if self.employees_df is not None: return # 已加载,跳过 # 配置 matplotlib 中文显示 plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"] plt.rcParams["axes.unicode_minus"] = False engine = create_engine(self.db_uri) self.employees_df = pd.read_sql("SELECT * FROM employees", engine) self.products_df = pd.read_sql("SELECT * FROM products", engine) self.orders_df = pd.read_sql("SELECT * FROM orders", engine) print("✅ 数据加载完成") print(f" 员工表:{len(self.employees_df)} 行") print(f" 产品表:{len(self.products_df)} 行") print(f" 订单表:{len(self.orders_df)} 行") # ============================================================ # 内部方法:获取可视化 Agent(延迟初始化) # ============================================================ def _get_visual_agent(self): """延迟初始化可视化 Agent。""" if self._visual_agent is not None: return self._visual_agent self._load_dataframes() # 构建沙箱全局变量 self._sandbox_globals = { "employees_df": self.employees_df, "products_df": self.products_df, "orders_df": self.orders_df, "pd": pd, "plt": plt, "sns": sns, "np": np, } # 创建工具 @tool def execute_python_code(code: str) -> str: """ 执行 Python 代码进行数据分析和可视化。 可用变量: - employees_df, products_df, orders_df - pd, plt, sns, np """ exec_globals = dict(self._sandbox_globals) exec_locals = {} output_buffer = StringIO() try: with redirect_stdout(output_buffer): exec(code, exec_globals, exec_locals) result = output_buffer.getvalue() if not result.strip(): result = "✅ 代码执行成功(无文本输出,可能已生成图表)" return f"执行成功:\n{result}" except Exception as e: error_detail = traceback.format_exc() return f"❌ 执行出错:{e}\n\n{error_detail}" viz_tools = [execute_python_code] self._visual_agent = create_agent( model=self.llm, tools=viz_tools, system_prompt=self.VISUALIZATION_PROMPT, ) print("✅ 数据可视化 Agent 创建完成") return self._visual_agent # ============================================================ # 公共方法:运行 SQL 查询 # ============================================================ def query(self, user_query: str) -> str: """ 使用 SQL Agent 执行自然语言到 SQL 的查询。 Args: user_query: 用户的自然语言问题 Returns: Agent 的分析结果文本 """ agent = self._get_sql_agent() result = agent.invoke({"messages": [{"role": "user", "content": user_query}]}) return result["messages"][-1].content # ============================================================ # 公共方法:运行可视化分析 # ============================================================ def visualize(self, user_query: str) -> str: """ 使用可视化 Agent 进行数据分析和图表生成。 Args: user_query: 用户的分析需求描述 Returns: Agent 的分析结果文本 """ agent = self._get_visual_agent() result = agent.invoke({"messages": [{"role": "user", "content": user_query}]}) return result["messages"][-1].content # ============================================================ # 公共方法:统一入口(自动意图路由) # ============================================================ def analyze(self, user_query: str) -> str: """ 统一分析入口:自动判断用户意图,路由到对应 Agent。 意图分类: - query:纯数据库查询 - visualize:纯数据可视化/统计分析 - both:先查询再可视化 Args: user_query: 用户的自然语言问题 Returns: 分析结果文本 """ classification_prompt = f"""判断以下用户问题属于哪种类型: - "query":需要查询数据库获取数据 - "visualize":需要画图或做统计分析 - "both":需要先查数据,再画图分析 用户问题:{user_query} 只回复一个词:query / visualize / both""" intent = self.llm.invoke(classification_prompt).content.strip().lower() if intent == "query": return self.query(user_query) elif intent == "visualize": return self.visualize(user_query) else: # both data_result = self.query(user_query) viz_input = f"基于以下数据进行可视化分析:\n{data_result}\n\n原始问题:{user_query}" return self.visualize(viz_input) # ============================================================ # 内部方法:获取或创建统一 Agent # ============================================================ def get_together_agent(self): TOGETHER_SYSTEM_PROMPT = """你是一名全能数据分析师,精通 SQL 查询和 Python 数据分析可视化。 ## 可用工具 1. **SQL 工具组**(数据库操作): - `sql_db_list_tables` — 查看数据库中有哪些表 - `sql_db_schema` — 获取指定表的字段结构 - `sql_db_query_checker` — 检查 SQL 语法是否正确 - `sql_db_query` — 执行 SQL 查询并返回结果 2. **Python 代码执行工具**(数据分析与可视化): - `execute_python_code` — 在沙箱中执行 Python 代码 ## 可用数据(Python 沙箱中已预加载) - `employees_df` — 员工表(id, name, department, salary, hire_date) - `products_df` — 产品表(id, product_name, category, price, stock) - `orders_df` — 订单表(id, employee_id, product_id, quantity, order_date) - `pd` (pandas), `plt` (matplotlib), `sns` (seaborn), `np` (numpy) ## 工作流程 1. **理解需求**:分析用户问题的核心诉求 2. **获取元数据**:如需查数据库,先用 `sql_db_list_tables` 和 `sql_db_schema` 了解表结构 3. **选择工具**: - 纯数据查询 → 用 SQL 工具 - 统计分析、画图 → 用 `execute_python_code` - 复杂分析 → 先用 SQL 查数据,再用 Python 做深度分析 4. **执行与验证**:SQL 先用 `sql_db_query_checker` 检查语法;Python 代码确保正确 5. **总结输出**:用中文给出简洁的业务洞察 ## 约束 - 只使用数据库中实际存在的表和字段 - SQL 单次查询结果限制在 50 条以内 - Python 代码绑图前设置中文字体:`plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']` - 设置 `plt.rcParams['axes.unicode_minus'] = False` - 图表尺寸用 `plt.figure(figsize=(10, 6))` - 必须添加标题、坐标轴标签 - 用 `print()` 输出关键统计量 - 回答简洁专业,不要啰嗦 """ """延迟初始化统一 Agent,合并 SQL 工具 + Python 执行工具。""" if self.together_agent is not None: return self.together_agent # ========== 1. SQL 工具组 ========== db = SQLDatabase.from_uri(self.db_uri) print(f"数据库连接成功,可用表:{db.get_usable_table_names()}") sql_toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) sql_tools = sql_toolkit.get_tools() print(f"SQL 工具加载完成({len(sql_tools)} 个):") for t in sql_tools: print(f" - {t.name}") # ========== 2. Python 代码执行工具 ========== self._load_dataframes() self._sandbox_globals = { "employees_df": self.employees_df, "products_df": self.products_df, "orders_df": self.orders_df, "pd": pd, "plt": plt, "sns": sns, "np": np, } @tool def execute_python_code(code: str) -> str: """ 执行 Python 代码进行数据分析和可视化。 可用变量:employees_df, products_df, orders_df, pd, plt, sns, np """ exec_globals = dict(self._sandbox_globals) exec_locals = {} output_buffer = StringIO() try: with redirect_stdout(output_buffer): exec(code, exec_globals, exec_locals) result = output_buffer.getvalue() if not result.strip(): result = "✅ 代码执行成功(无文本输出,可能已生成图表)" return f"执行成功:\n{result}" except Exception as e: error_detail = traceback.format_exc() return f"❌ 执行出错:{e}\n\n{error_detail}" # ========== 3. 合并所有工具 ========== all_tools = sql_tools + [execute_python_code] print(f"\n✅ 工具合并完成,共 {len(all_tools)} 个工具:") for t in all_tools: print(f" - {t.name}") # ========== 4. 创建统一 Agent ========== self.together_agent = create_agent( model=self.llm, tools=all_tools, system_prompt=TOGETHER_SYSTEM_PROMPT, ) print("\n🤖 统一分析 Agent 创建完成!") return self.together_agent # ============================================================ # 统⼀ Agent 模式(作业) # ============================================================ def together_analyze(self, user_query: str) -> str: """ 统一分析入口:LLM 自主决定调用 SQL 工具或 Python 工具。 Args: user_query: 用户的自然语言问题 Returns: Agent 的分析结果 """ agent = self.get_together_agent() result = agent.invoke({"messages": [{"role": "user", "content": user_query}]}) return result["messages"][-1].content # ============================================================ # 使用示例 # ============================================================ if __name__ == "__main__": # 1. 创建系统实例 system = DataAnalysisSystem() # 2. 测试 LLM 连接 print("\n--- LLM 测试 ---") print(system.test_llm("你好,请用一句话介绍自己")) # 3. 初始化数据库(只需执行一次) # print("\n--- 初始化数据库 ---") # system.init_database() # 4. 运行分析 # print("\n--- SQL 查询 ---") # result = system.query("技术部的平均工资是多少?") # print(result) # # print("\n--- 可视化分析 ---") # result = system.visualize("统计各部门人数并画柱状图") # print(result) # print("\n--- 自动路由 ---") # result = system.analyze("销售部谁的总销售额最高?") # print(result) print("\n--- 统一分析:LLM 自动选择工具 ---") result = system.together_analyze("统计各部门人数并画柱状图?") print(result)