Browse Source

上传文件至 ''

01_data_analysis
zhouche 13 hours ago
parent
commit
65aeec5a3f
1 changed files with 549 additions and 0 deletions
  1. 549 0
      统⼀ Agent 模式.py

+ 549 - 0
统⼀ Agent 模式.py

@@ -0,0 +1,549 @@
+import os
+import traceback
+from io import StringIO
+from contextlib import redirect_stdout
+
+import mysql.connector
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+from dotenv import load_dotenv
+from langchain.tools import tool
+from langchain.agents import create_agent
+from langchain_openai import ChatOpenAI
+from langchain_community.utilities import SQLDatabase
+from langchain_community.agent_toolkits import SQLDatabaseToolkit
+from sqlalchemy import create_engine
+
+
+class DataAnalysisSystem:
+    """
+    数据分析系统:集成 NL2SQL、数据可视化、意图路由的统一分析平台。
+    支持硅基流动/DeepSeek 等大模型 API。
+    """
+
+    # ============================================================
+    # 类级常量:Prompt 模板
+    # ============================================================
+    SQL_AGENT_PROMPT = """你是一名专业的 SQL 数据分析师。
+
+        ## 工作流程
+        1. 先用 sql_db_list_tables 查看数据库中有哪些表
+        2. 用 sql_db_schema 获取相关表的字段结构和类型
+        3. 生成 SQL 之前,用 sql_db_query_checker 检查语法
+        4. 确认无误后,用 sql_db_query 执行查询
+        5. 用中文总结查询结果,给出简洁的业务洞察
+        
+        ## 约束
+        - 只使用数据库中实际存在的表和字段,不要凭空编造
+        - 单次查询结果限制在 50 条以内
+        - 如果查询出错,分析错误原因后重新生成 SQL
+        - 回答要简洁专业,不要啰嗦
+        """
+
+    VISUALIZATION_PROMPT = """你是一名资深数据分析师,精通 Python、Pandas 和 Matplotlib 数据可视化。
+
+        ## 可用数据
+        1. employees_df — 员工表(字段:id, name, department, salary, hire_date)
+        2. products_df  — 产品表(字段:id, product_name, category, price, stock)
+        3. orders_df    — 订单表(字段:id, employee_id, product_id, quantity, order_date)
+        
+        ## 工作流程
+        1. 理解用户的分析需求
+        2. 用 execute_python_code 工具编写并执行 Python 代码
+        3. 先做数据探索(head、describe、info),再做深入分析
+        4. 用中文解释分析结果,给出业务洞察
+
+        ## 代码规范
+        - 绑图前设置中文字体:plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']
+        - 设置 plt.rcParams['axes.unicode_minus'] = False
+        - 图表尺寸统一用 plt.figure(figsize=(10, 6))
+        - 必须添加标题、坐标轴标签,让图表自解释
+        - 用 print() 输出关键统计量,不要只画图不说话
+        - 图表标题用英文(避免渲染问题),但用中文向用户解释结果
+        
+        ## 注意事项
+        - 每次只执行一段完整的代码,不要拆成多段
+        - 先探索数据结构,再做分析——不要上来就画图
+        - 结果要有业务洞察,不只是"最大值是 XXX"
+        """
+
+    # ============================================================
+    # 构造函数
+    # ============================================================
+    def __init__(self, env_file: str = ".env"):
+        """
+        初始化数据分析系统。
+
+        Args:
+            env_file: 环境变量文件路径,默认当前目录的 .env
+        """
+        # 加载环境变量
+        load_dotenv(dotenv_path=env_file)
+
+        # 初始化 LLM
+        self.api_key = os.getenv("API_KEY")
+        self.model_name = os.getenv("MODEL_NAME")
+        self.base_url = os.getenv("BASE_URL")
+
+        self.llm = ChatOpenAI(
+            model=self.model_name,
+            api_key=self.api_key,
+            base_url=self.base_url,
+            temperature=0,
+        )
+        print("✅ 大语言模型初始化完成")
+
+        # 构建数据库 URI
+        self.db_uri = (
+            f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
+            f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
+        )
+        print("✅ 数据库 URI 构建完成")
+
+        # 数据缓存
+        self.employees_df = None
+        self.products_df = None
+        self.orders_df = None
+
+        # Agent 缓存(延迟初始化)
+        self._sql_agent = None
+        self._visual_agent = None
+        self._sandbox_globals = None
+
+        self.together_agent = None
+
+
+    # ============================================================
+    # 公共方法:测试模型连接
+    # ============================================================
+    def test_llm(self, prompt: str = "什么是量子纠缠?") -> str:
+        """测试 LLM 连接是否正常。"""
+        response = self.llm.invoke(prompt)
+        return response.content
+
+    # ============================================================
+    # 公共方法:初始化数据库(建表 + 插数据)
+    # ============================================================
+    def init_database(self) -> None:
+        """
+        创建演示数据库:employees、products、orders 三张表,
+        并插入示例数据。
+        """
+        conn = mysql.connector.connect(
+            host=os.getenv("DB_HOST"),
+            port=int(os.getenv("DB_PORT")),
+            user=os.getenv("DB_USER"),
+            password=os.getenv("DB_PASSWORD"),
+            database=os.getenv("DB_NAME"),
+        )
+        cursor = conn.cursor()
+
+        # 建表
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS employees (
+            id          INT PRIMARY KEY AUTO_INCREMENT,
+            name        VARCHAR(50)  NOT NULL,
+            department  VARCHAR(50)  NOT NULL,
+            salary      DECIMAL(10,2) NOT NULL,
+            hire_date   DATE
+        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+        """)
+
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS products (
+            id           INT PRIMARY KEY AUTO_INCREMENT,
+            product_name VARCHAR(100) NOT NULL,
+            category     VARCHAR(50)  NOT NULL,
+            price        DECIMAL(10,2) NOT NULL,
+            stock        INT DEFAULT 0
+        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+        """)
+
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS orders (
+            id           INT PRIMARY KEY AUTO_INCREMENT,
+            employee_id  INT NOT NULL,
+            product_id   INT NOT NULL,
+            quantity     INT NOT NULL,
+            order_date   DATE NOT NULL,
+            FOREIGN KEY (employee_id) REFERENCES employees(id),
+            FOREIGN KEY (product_id)  REFERENCES products(id)
+        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+        """)
+
+        # 插入数据
+        employees_data = [
+            (1, "张三", "技术部", 20000.00, "2023-01-15"),
+            (2, "李四", "销售部", 11000.00, "2023-02-20"),
+            (3, "王五", "技术部", 16000.00, "2022-11-10"),
+            (4, "赵六", "人力资源", 5000.00, "2023-03-01"),
+            (5, "钱七", "销售部", 17000.00, "2022-12-05"),
+        ]
+        products_data = [
+            (1, "笔记本电脑", "电子产品", 6999.00, 500),
+            (2, "机械键盘", "电子产品", 399.00, 1000),
+            (3, "办公椅", "办公用品", 499.00, 300),
+            (4, "显示器", "电子产品", 1200.00, 400),
+        ]
+        orders_data = [
+            (1, 1, 1, 2, "2024-01-15"),
+            (2, 2, 2, 15, "2024-01-16"),
+            (3, 3, 1, 10, "2024-01-17"),
+            (4, 5, 3, 6, "2024-01-18"),
+            (5, 2, 4, 5, "2024-01-19"),
+        ]
+
+        cursor.executemany(
+            "INSERT IGNORE INTO employees VALUES (%s,%s,%s,%s,%s)", employees_data
+        )
+        cursor.executemany(
+            "INSERT IGNORE INTO products  VALUES (%s,%s,%s,%s,%s)", products_data
+        )
+        cursor.executemany(
+            "INSERT IGNORE INTO orders     VALUES (%s,%s,%s,%s,%s)", orders_data
+        )
+
+        conn.commit()
+        conn.close()
+
+        print("✅ 数据库初始化完成")
+        print("   - employees 表:5 条员工记录")
+        print("   - products  表:4 条产品记录")
+        print("   - orders    表:5 条订单记录")
+
+    # ============================================================
+    # 内部方法:获取 SQL Agent(延迟初始化)
+    # ============================================================
+    def _get_sql_agent(self):
+        """延迟初始化 SQL Agent。"""
+        if self._sql_agent is not None:
+            return self._sql_agent
+
+        db = SQLDatabase.from_uri(self.db_uri)
+        print(f"数据库连接成功,可用表:{db.get_usable_table_names()}")
+
+        toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
+        tools = toolkit.get_tools()
+
+        print(f"SQL 工具包已加载({len(tools)} 个工具):")
+        for t in tools:
+            print(f"   - {t.name}")
+
+        self._sql_agent = create_agent(
+            model=self.llm,
+            tools=tools,
+            system_prompt=self.SQL_AGENT_PROMPT,
+        )
+        print("NL2SQL Agent 创建完成")
+        return self._sql_agent
+
+    # ============================================================
+    # 内部方法:加载数据到 DataFrame(用于可视化)
+    # ============================================================
+    def _load_dataframes(self):
+        """从数据库加载数据到 Pandas DataFrame。"""
+        if self.employees_df is not None:
+            return  # 已加载,跳过
+
+        # 配置 matplotlib 中文显示
+        plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"]
+        plt.rcParams["axes.unicode_minus"] = False
+
+        engine = create_engine(self.db_uri)
+
+        self.employees_df = pd.read_sql("SELECT * FROM employees", engine)
+        self.products_df = pd.read_sql("SELECT * FROM products", engine)
+        self.orders_df = pd.read_sql("SELECT * FROM orders", engine)
+
+        print("✅ 数据加载完成")
+        print(f"   员工表:{len(self.employees_df)} 行")
+        print(f"   产品表:{len(self.products_df)} 行")
+        print(f"   订单表:{len(self.orders_df)} 行")
+
+    # ============================================================
+    # 内部方法:获取可视化 Agent(延迟初始化)
+    # ============================================================
+    def _get_visual_agent(self):
+        """延迟初始化可视化 Agent。"""
+        if self._visual_agent is not None:
+            return self._visual_agent
+
+        self._load_dataframes()
+
+        # 构建沙箱全局变量
+        self._sandbox_globals = {
+            "employees_df": self.employees_df,
+            "products_df": self.products_df,
+            "orders_df": self.orders_df,
+            "pd": pd,
+            "plt": plt,
+            "sns": sns,
+            "np": np,
+        }
+
+        # 创建工具
+        @tool
+        def execute_python_code(code: str) -> str:
+            """
+            执行 Python 代码进行数据分析和可视化。
+
+            可用变量:
+              - employees_df, products_df, orders_df
+              - pd, plt, sns, np
+            """
+            exec_globals = dict(self._sandbox_globals)
+            exec_locals = {}
+            output_buffer = StringIO()
+
+            try:
+                with redirect_stdout(output_buffer):
+                    exec(code, exec_globals, exec_locals)
+                result = output_buffer.getvalue()
+                if not result.strip():
+                    result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
+                return f"执行成功:\n{result}"
+            except Exception as e:
+                error_detail = traceback.format_exc()
+                return f"❌ 执行出错:{e}\n\n{error_detail}"
+
+        viz_tools = [execute_python_code]
+
+        self._visual_agent = create_agent(
+            model=self.llm,
+            tools=viz_tools,
+            system_prompt=self.VISUALIZATION_PROMPT,
+        )
+        print("✅ 数据可视化 Agent 创建完成")
+        return self._visual_agent
+
+    # ============================================================
+    # 公共方法:运行 SQL 查询
+    # ============================================================
+    def query(self, user_query: str) -> str:
+        """
+        使用 SQL Agent 执行自然语言到 SQL 的查询。
+
+        Args:
+            user_query: 用户的自然语言问题
+
+        Returns:
+            Agent 的分析结果文本
+        """
+        agent = self._get_sql_agent()
+        result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
+        return result["messages"][-1].content
+
+    # ============================================================
+    # 公共方法:运行可视化分析
+    # ============================================================
+    def visualize(self, user_query: str) -> str:
+        """
+        使用可视化 Agent 进行数据分析和图表生成。
+
+        Args:
+            user_query: 用户的分析需求描述
+
+        Returns:
+            Agent 的分析结果文本
+        """
+        agent = self._get_visual_agent()
+        result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
+        return result["messages"][-1].content
+
+    # ============================================================
+    # 公共方法:统一入口(自动意图路由)
+    # ============================================================
+    def analyze(self, user_query: str) -> str:
+        """
+        统一分析入口:自动判断用户意图,路由到对应 Agent。
+
+        意图分类:
+          - query:纯数据库查询
+          - visualize:纯数据可视化/统计分析
+          - both:先查询再可视化
+
+        Args:
+            user_query: 用户的自然语言问题
+
+        Returns:
+            分析结果文本
+        """
+        classification_prompt = f"""判断以下用户问题属于哪种类型:
+            - "query":需要查询数据库获取数据
+            - "visualize":需要画图或做统计分析
+            - "both":需要先查数据,再画图分析
+
+            用户问题:{user_query}
+
+            只回复一个词:query / visualize / both"""
+
+        intent = self.llm.invoke(classification_prompt).content.strip().lower()
+
+        if intent == "query":
+            return self.query(user_query)
+        elif intent == "visualize":
+            return self.visualize(user_query)
+        else:  # both
+            data_result = self.query(user_query)
+            viz_input = f"基于以下数据进行可视化分析:\n{data_result}\n\n原始问题:{user_query}"
+            return self.visualize(viz_input)
+
+    # ============================================================
+    # 内部方法:获取或创建统一 Agent
+    # ============================================================
+    def get_together_agent(self):
+        TOGETHER_SYSTEM_PROMPT = """你是一名全能数据分析师,精通 SQL 查询和 Python 数据分析可视化。
+               ## 可用工具
+               1. **SQL 工具组**(数据库操作):
+                  - `sql_db_list_tables` — 查看数据库中有哪些表
+                  - `sql_db_schema` — 获取指定表的字段结构
+                  - `sql_db_query_checker` — 检查 SQL 语法是否正确
+                  - `sql_db_query` — 执行 SQL 查询并返回结果
+
+               2. **Python 代码执行工具**(数据分析与可视化):
+                  - `execute_python_code` — 在沙箱中执行 Python 代码
+
+               ## 可用数据(Python 沙箱中已预加载)
+               - `employees_df` — 员工表(id, name, department, salary, hire_date)
+               - `products_df` — 产品表(id, product_name, category, price, stock)
+               - `orders_df` — 订单表(id, employee_id, product_id, quantity, order_date)
+               - `pd` (pandas), `plt` (matplotlib), `sns` (seaborn), `np` (numpy)
+
+               ## 工作流程
+               1. **理解需求**:分析用户问题的核心诉求
+               2. **获取元数据**:如需查数据库,先用 `sql_db_list_tables` 和 `sql_db_schema` 了解表结构
+               3. **选择工具**:
+                  - 纯数据查询 → 用 SQL 工具
+                  - 统计分析、画图 → 用 `execute_python_code`
+                  - 复杂分析 → 先用 SQL 查数据,再用 Python 做深度分析
+               4. **执行与验证**:SQL 先用 `sql_db_query_checker` 检查语法;Python 代码确保正确
+               5. **总结输出**:用中文给出简洁的业务洞察
+
+               ## 约束
+               - 只使用数据库中实际存在的表和字段
+               - SQL 单次查询结果限制在 50 条以内
+               - Python 代码绑图前设置中文字体:`plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']`
+               - 设置 `plt.rcParams['axes.unicode_minus'] = False`
+               - 图表尺寸用 `plt.figure(figsize=(10, 6))`
+               - 必须添加标题、坐标轴标签
+               - 用 `print()` 输出关键统计量
+               - 回答简洁专业,不要啰嗦
+               """
+        """延迟初始化统一 Agent,合并 SQL 工具 + Python 执行工具。"""
+        if self.together_agent is not None:
+            return self.together_agent
+
+        # ========== 1. SQL 工具组 ==========
+        db = SQLDatabase.from_uri(self.db_uri)
+        print(f"数据库连接成功,可用表:{db.get_usable_table_names()}")
+
+        sql_toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
+        sql_tools = sql_toolkit.get_tools()
+
+        print(f"SQL 工具加载完成({len(sql_tools)} 个):")
+        for t in sql_tools:
+            print(f"   - {t.name}")
+
+        # ========== 2. Python 代码执行工具 ==========
+        self._load_dataframes()
+
+        self._sandbox_globals = {
+            "employees_df": self.employees_df,
+            "products_df": self.products_df,
+            "orders_df": self.orders_df,
+            "pd": pd,
+            "plt": plt,
+            "sns": sns,
+            "np": np,
+        }
+
+        @tool
+        def execute_python_code(code: str) -> str:
+            """
+            执行 Python 代码进行数据分析和可视化。
+
+            可用变量:employees_df, products_df, orders_df, pd, plt, sns, np
+            """
+            exec_globals = dict(self._sandbox_globals)
+            exec_locals = {}
+            output_buffer = StringIO()
+
+            try:
+                with redirect_stdout(output_buffer):
+                    exec(code, exec_globals, exec_locals)
+                result = output_buffer.getvalue()
+                if not result.strip():
+                    result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
+                return f"执行成功:\n{result}"
+            except Exception as e:
+                error_detail = traceback.format_exc()
+                return f"❌ 执行出错:{e}\n\n{error_detail}"
+
+        # ========== 3. 合并所有工具 ==========
+        all_tools = sql_tools + [execute_python_code]
+
+        print(f"\n✅ 工具合并完成,共 {len(all_tools)} 个工具:")
+        for t in all_tools:
+            print(f"   - {t.name}")
+
+        # ========== 4. 创建统一 Agent ==========
+        self.together_agent = create_agent(
+            model=self.llm,
+            tools=all_tools,
+            system_prompt=TOGETHER_SYSTEM_PROMPT,
+        )
+        print("\n🤖 统一分析 Agent 创建完成!")
+        return self.together_agent
+
+    # ============================================================
+    # 统⼀ Agent 模式(作业)
+    # ============================================================
+    def together_analyze(self, user_query: str) -> str:
+        """
+        统一分析入口:LLM 自主决定调用 SQL 工具或 Python 工具。
+
+        Args:
+            user_query: 用户的自然语言问题
+
+        Returns:
+            Agent 的分析结果
+        """
+        agent = self.get_together_agent()
+        result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
+        return result["messages"][-1].content
+
+
+
+# ============================================================
+# 使用示例
+# ============================================================
+if __name__ == "__main__":
+    # 1. 创建系统实例
+    system = DataAnalysisSystem()
+
+    # 2. 测试 LLM 连接
+    print("\n--- LLM 测试 ---")
+    print(system.test_llm("你好,请用一句话介绍自己"))
+
+    # 3. 初始化数据库(只需执行一次)
+    # print("\n--- 初始化数据库 ---")
+    # system.init_database()
+
+    # 4. 运行分析
+    # print("\n--- SQL 查询 ---")
+    # result = system.query("技术部的平均工资是多少?")
+    # print(result)
+    #
+    # print("\n--- 可视化分析 ---")
+    # result = system.visualize("统计各部门人数并画柱状图")
+    # print(result)
+
+    # print("\n--- 自动路由 ---")
+    # result = system.analyze("销售部谁的总销售额最高?")
+    # print(result)
+
+    print("\n--- 统一分析:LLM 自动选择工具 ---")
+    result = system.together_analyze("统计各部门人数并画柱状图?")
+    print(result)