|
|
@@ -1,549 +0,0 @@
|
|
|
-import os
|
|
|
-import traceback
|
|
|
-from io import StringIO
|
|
|
-from contextlib import redirect_stdout
|
|
|
-
|
|
|
-import mysql.connector
|
|
|
-import numpy as np
|
|
|
-import pandas as pd
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-import seaborn as sns
|
|
|
-from dotenv import load_dotenv
|
|
|
-from langchain.tools import tool
|
|
|
-from langchain.agents import create_agent
|
|
|
-from langchain_openai import ChatOpenAI
|
|
|
-from langchain_community.utilities import SQLDatabase
|
|
|
-from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
|
-from sqlalchemy import create_engine
|
|
|
-
|
|
|
-
|
|
|
-class DataAnalysisSystem:
|
|
|
- """
|
|
|
- 数据分析系统:集成 NL2SQL、数据可视化、意图路由的统一分析平台。
|
|
|
- 支持硅基流动/DeepSeek 等大模型 API。
|
|
|
- """
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 类级常量:Prompt 模板
|
|
|
- # ============================================================
|
|
|
- SQL_AGENT_PROMPT = """你是一名专业的 SQL 数据分析师。
|
|
|
-
|
|
|
- ## 工作流程
|
|
|
- 1. 先用 sql_db_list_tables 查看数据库中有哪些表
|
|
|
- 2. 用 sql_db_schema 获取相关表的字段结构和类型
|
|
|
- 3. 生成 SQL 之前,用 sql_db_query_checker 检查语法
|
|
|
- 4. 确认无误后,用 sql_db_query 执行查询
|
|
|
- 5. 用中文总结查询结果,给出简洁的业务洞察
|
|
|
-
|
|
|
- ## 约束
|
|
|
- - 只使用数据库中实际存在的表和字段,不要凭空编造
|
|
|
- - 单次查询结果限制在 50 条以内
|
|
|
- - 如果查询出错,分析错误原因后重新生成 SQL
|
|
|
- - 回答要简洁专业,不要啰嗦
|
|
|
- """
|
|
|
-
|
|
|
- VISUALIZATION_PROMPT = """你是一名资深数据分析师,精通 Python、Pandas 和 Matplotlib 数据可视化。
|
|
|
-
|
|
|
- ## 可用数据
|
|
|
- 1. employees_df — 员工表(字段:id, name, department, salary, hire_date)
|
|
|
- 2. products_df — 产品表(字段:id, product_name, category, price, stock)
|
|
|
- 3. orders_df — 订单表(字段:id, employee_id, product_id, quantity, order_date)
|
|
|
-
|
|
|
- ## 工作流程
|
|
|
- 1. 理解用户的分析需求
|
|
|
- 2. 用 execute_python_code 工具编写并执行 Python 代码
|
|
|
- 3. 先做数据探索(head、describe、info),再做深入分析
|
|
|
- 4. 用中文解释分析结果,给出业务洞察
|
|
|
-
|
|
|
- ## 代码规范
|
|
|
- - 绑图前设置中文字体:plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']
|
|
|
- - 设置 plt.rcParams['axes.unicode_minus'] = False
|
|
|
- - 图表尺寸统一用 plt.figure(figsize=(10, 6))
|
|
|
- - 必须添加标题、坐标轴标签,让图表自解释
|
|
|
- - 用 print() 输出关键统计量,不要只画图不说话
|
|
|
- - 图表标题用英文(避免渲染问题),但用中文向用户解释结果
|
|
|
-
|
|
|
- ## 注意事项
|
|
|
- - 每次只执行一段完整的代码,不要拆成多段
|
|
|
- - 先探索数据结构,再做分析——不要上来就画图
|
|
|
- - 结果要有业务洞察,不只是"最大值是 XXX"
|
|
|
- """
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 构造函数
|
|
|
- # ============================================================
|
|
|
- def __init__(self, env_file: str = ".env"):
|
|
|
- """
|
|
|
- 初始化数据分析系统。
|
|
|
-
|
|
|
- Args:
|
|
|
- env_file: 环境变量文件路径,默认当前目录的 .env
|
|
|
- """
|
|
|
- # 加载环境变量
|
|
|
- load_dotenv(dotenv_path=env_file)
|
|
|
-
|
|
|
- # 初始化 LLM
|
|
|
- self.api_key = os.getenv("API_KEY")
|
|
|
- self.model_name = os.getenv("MODEL_NAME")
|
|
|
- self.base_url = os.getenv("BASE_URL")
|
|
|
-
|
|
|
- self.llm = ChatOpenAI(
|
|
|
- model=self.model_name,
|
|
|
- api_key=self.api_key,
|
|
|
- base_url=self.base_url,
|
|
|
- temperature=0,
|
|
|
- )
|
|
|
- print("✅ 大语言模型初始化完成")
|
|
|
-
|
|
|
- # 构建数据库 URI
|
|
|
- self.db_uri = (
|
|
|
- f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
|
|
|
- f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
|
|
|
- )
|
|
|
- print("✅ 数据库 URI 构建完成")
|
|
|
-
|
|
|
- # 数据缓存
|
|
|
- self.employees_df = None
|
|
|
- self.products_df = None
|
|
|
- self.orders_df = None
|
|
|
-
|
|
|
- # Agent 缓存(延迟初始化)
|
|
|
- self._sql_agent = None
|
|
|
- self._visual_agent = None
|
|
|
- self._sandbox_globals = None
|
|
|
-
|
|
|
- self.together_agent = None
|
|
|
-
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 公共方法:测试模型连接
|
|
|
- # ============================================================
|
|
|
- def test_llm(self, prompt: str = "什么是量子纠缠?") -> str:
|
|
|
- """测试 LLM 连接是否正常。"""
|
|
|
- response = self.llm.invoke(prompt)
|
|
|
- return response.content
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 公共方法:初始化数据库(建表 + 插数据)
|
|
|
- # ============================================================
|
|
|
- def init_database(self) -> None:
|
|
|
- """
|
|
|
- 创建演示数据库:employees、products、orders 三张表,
|
|
|
- 并插入示例数据。
|
|
|
- """
|
|
|
- conn = mysql.connector.connect(
|
|
|
- host=os.getenv("DB_HOST"),
|
|
|
- port=int(os.getenv("DB_PORT")),
|
|
|
- user=os.getenv("DB_USER"),
|
|
|
- password=os.getenv("DB_PASSWORD"),
|
|
|
- database=os.getenv("DB_NAME"),
|
|
|
- )
|
|
|
- cursor = conn.cursor()
|
|
|
-
|
|
|
- # 建表
|
|
|
- cursor.execute("""
|
|
|
- CREATE TABLE IF NOT EXISTS employees (
|
|
|
- id INT PRIMARY KEY AUTO_INCREMENT,
|
|
|
- name VARCHAR(50) NOT NULL,
|
|
|
- department VARCHAR(50) NOT NULL,
|
|
|
- salary DECIMAL(10,2) NOT NULL,
|
|
|
- hire_date DATE
|
|
|
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
|
|
- """)
|
|
|
-
|
|
|
- cursor.execute("""
|
|
|
- CREATE TABLE IF NOT EXISTS products (
|
|
|
- id INT PRIMARY KEY AUTO_INCREMENT,
|
|
|
- product_name VARCHAR(100) NOT NULL,
|
|
|
- category VARCHAR(50) NOT NULL,
|
|
|
- price DECIMAL(10,2) NOT NULL,
|
|
|
- stock INT DEFAULT 0
|
|
|
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
|
|
- """)
|
|
|
-
|
|
|
- cursor.execute("""
|
|
|
- CREATE TABLE IF NOT EXISTS orders (
|
|
|
- id INT PRIMARY KEY AUTO_INCREMENT,
|
|
|
- employee_id INT NOT NULL,
|
|
|
- product_id INT NOT NULL,
|
|
|
- quantity INT NOT NULL,
|
|
|
- order_date DATE NOT NULL,
|
|
|
- FOREIGN KEY (employee_id) REFERENCES employees(id),
|
|
|
- FOREIGN KEY (product_id) REFERENCES products(id)
|
|
|
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
|
|
- """)
|
|
|
-
|
|
|
- # 插入数据
|
|
|
- employees_data = [
|
|
|
- (1, "张三", "技术部", 20000.00, "2023-01-15"),
|
|
|
- (2, "李四", "销售部", 11000.00, "2023-02-20"),
|
|
|
- (3, "王五", "技术部", 16000.00, "2022-11-10"),
|
|
|
- (4, "赵六", "人力资源", 5000.00, "2023-03-01"),
|
|
|
- (5, "钱七", "销售部", 17000.00, "2022-12-05"),
|
|
|
- ]
|
|
|
- products_data = [
|
|
|
- (1, "笔记本电脑", "电子产品", 6999.00, 500),
|
|
|
- (2, "机械键盘", "电子产品", 399.00, 1000),
|
|
|
- (3, "办公椅", "办公用品", 499.00, 300),
|
|
|
- (4, "显示器", "电子产品", 1200.00, 400),
|
|
|
- ]
|
|
|
- orders_data = [
|
|
|
- (1, 1, 1, 2, "2024-01-15"),
|
|
|
- (2, 2, 2, 15, "2024-01-16"),
|
|
|
- (3, 3, 1, 10, "2024-01-17"),
|
|
|
- (4, 5, 3, 6, "2024-01-18"),
|
|
|
- (5, 2, 4, 5, "2024-01-19"),
|
|
|
- ]
|
|
|
-
|
|
|
- cursor.executemany(
|
|
|
- "INSERT IGNORE INTO employees VALUES (%s,%s,%s,%s,%s)", employees_data
|
|
|
- )
|
|
|
- cursor.executemany(
|
|
|
- "INSERT IGNORE INTO products VALUES (%s,%s,%s,%s,%s)", products_data
|
|
|
- )
|
|
|
- cursor.executemany(
|
|
|
- "INSERT IGNORE INTO orders VALUES (%s,%s,%s,%s,%s)", orders_data
|
|
|
- )
|
|
|
-
|
|
|
- conn.commit()
|
|
|
- conn.close()
|
|
|
-
|
|
|
- print("✅ 数据库初始化完成")
|
|
|
- print(" - employees 表:5 条员工记录")
|
|
|
- print(" - products 表:4 条产品记录")
|
|
|
- print(" - orders 表:5 条订单记录")
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 内部方法:获取 SQL Agent(延迟初始化)
|
|
|
- # ============================================================
|
|
|
- def _get_sql_agent(self):
|
|
|
- """延迟初始化 SQL Agent。"""
|
|
|
- if self._sql_agent is not None:
|
|
|
- return self._sql_agent
|
|
|
-
|
|
|
- db = SQLDatabase.from_uri(self.db_uri)
|
|
|
- print(f"数据库连接成功,可用表:{db.get_usable_table_names()}")
|
|
|
-
|
|
|
- toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
|
|
- tools = toolkit.get_tools()
|
|
|
-
|
|
|
- print(f"SQL 工具包已加载({len(tools)} 个工具):")
|
|
|
- for t in tools:
|
|
|
- print(f" - {t.name}")
|
|
|
-
|
|
|
- self._sql_agent = create_agent(
|
|
|
- model=self.llm,
|
|
|
- tools=tools,
|
|
|
- system_prompt=self.SQL_AGENT_PROMPT,
|
|
|
- )
|
|
|
- print("NL2SQL Agent 创建完成")
|
|
|
- return self._sql_agent
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 内部方法:加载数据到 DataFrame(用于可视化)
|
|
|
- # ============================================================
|
|
|
- def _load_dataframes(self):
|
|
|
- """从数据库加载数据到 Pandas DataFrame。"""
|
|
|
- if self.employees_df is not None:
|
|
|
- return # 已加载,跳过
|
|
|
-
|
|
|
- # 配置 matplotlib 中文显示
|
|
|
- plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"]
|
|
|
- plt.rcParams["axes.unicode_minus"] = False
|
|
|
-
|
|
|
- engine = create_engine(self.db_uri)
|
|
|
-
|
|
|
- self.employees_df = pd.read_sql("SELECT * FROM employees", engine)
|
|
|
- self.products_df = pd.read_sql("SELECT * FROM products", engine)
|
|
|
- self.orders_df = pd.read_sql("SELECT * FROM orders", engine)
|
|
|
-
|
|
|
- print("✅ 数据加载完成")
|
|
|
- print(f" 员工表:{len(self.employees_df)} 行")
|
|
|
- print(f" 产品表:{len(self.products_df)} 行")
|
|
|
- print(f" 订单表:{len(self.orders_df)} 行")
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 内部方法:获取可视化 Agent(延迟初始化)
|
|
|
- # ============================================================
|
|
|
- def _get_visual_agent(self):
|
|
|
- """延迟初始化可视化 Agent。"""
|
|
|
- if self._visual_agent is not None:
|
|
|
- return self._visual_agent
|
|
|
-
|
|
|
- self._load_dataframes()
|
|
|
-
|
|
|
- # 构建沙箱全局变量
|
|
|
- self._sandbox_globals = {
|
|
|
- "employees_df": self.employees_df,
|
|
|
- "products_df": self.products_df,
|
|
|
- "orders_df": self.orders_df,
|
|
|
- "pd": pd,
|
|
|
- "plt": plt,
|
|
|
- "sns": sns,
|
|
|
- "np": np,
|
|
|
- }
|
|
|
-
|
|
|
- # 创建工具
|
|
|
- @tool
|
|
|
- def execute_python_code(code: str) -> str:
|
|
|
- """
|
|
|
- 执行 Python 代码进行数据分析和可视化。
|
|
|
-
|
|
|
- 可用变量:
|
|
|
- - employees_df, products_df, orders_df
|
|
|
- - pd, plt, sns, np
|
|
|
- """
|
|
|
- exec_globals = dict(self._sandbox_globals)
|
|
|
- exec_locals = {}
|
|
|
- output_buffer = StringIO()
|
|
|
-
|
|
|
- try:
|
|
|
- with redirect_stdout(output_buffer):
|
|
|
- exec(code, exec_globals, exec_locals)
|
|
|
- result = output_buffer.getvalue()
|
|
|
- if not result.strip():
|
|
|
- result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
|
|
|
- return f"执行成功:\n{result}"
|
|
|
- except Exception as e:
|
|
|
- error_detail = traceback.format_exc()
|
|
|
- return f"❌ 执行出错:{e}\n\n{error_detail}"
|
|
|
-
|
|
|
- viz_tools = [execute_python_code]
|
|
|
-
|
|
|
- self._visual_agent = create_agent(
|
|
|
- model=self.llm,
|
|
|
- tools=viz_tools,
|
|
|
- system_prompt=self.VISUALIZATION_PROMPT,
|
|
|
- )
|
|
|
- print("✅ 数据可视化 Agent 创建完成")
|
|
|
- return self._visual_agent
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 公共方法:运行 SQL 查询
|
|
|
- # ============================================================
|
|
|
- def query(self, user_query: str) -> str:
|
|
|
- """
|
|
|
- 使用 SQL Agent 执行自然语言到 SQL 的查询。
|
|
|
-
|
|
|
- Args:
|
|
|
- user_query: 用户的自然语言问题
|
|
|
-
|
|
|
- Returns:
|
|
|
- Agent 的分析结果文本
|
|
|
- """
|
|
|
- agent = self._get_sql_agent()
|
|
|
- result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
|
|
|
- return result["messages"][-1].content
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 公共方法:运行可视化分析
|
|
|
- # ============================================================
|
|
|
- def visualize(self, user_query: str) -> str:
|
|
|
- """
|
|
|
- 使用可视化 Agent 进行数据分析和图表生成。
|
|
|
-
|
|
|
- Args:
|
|
|
- user_query: 用户的分析需求描述
|
|
|
-
|
|
|
- Returns:
|
|
|
- Agent 的分析结果文本
|
|
|
- """
|
|
|
- agent = self._get_visual_agent()
|
|
|
- result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
|
|
|
- return result["messages"][-1].content
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 公共方法:统一入口(自动意图路由)
|
|
|
- # ============================================================
|
|
|
- def analyze(self, user_query: str) -> str:
|
|
|
- """
|
|
|
- 统一分析入口:自动判断用户意图,路由到对应 Agent。
|
|
|
-
|
|
|
- 意图分类:
|
|
|
- - query:纯数据库查询
|
|
|
- - visualize:纯数据可视化/统计分析
|
|
|
- - both:先查询再可视化
|
|
|
-
|
|
|
- Args:
|
|
|
- user_query: 用户的自然语言问题
|
|
|
-
|
|
|
- Returns:
|
|
|
- 分析结果文本
|
|
|
- """
|
|
|
- classification_prompt = f"""判断以下用户问题属于哪种类型:
|
|
|
- - "query":需要查询数据库获取数据
|
|
|
- - "visualize":需要画图或做统计分析
|
|
|
- - "both":需要先查数据,再画图分析
|
|
|
-
|
|
|
- 用户问题:{user_query}
|
|
|
-
|
|
|
- 只回复一个词:query / visualize / both"""
|
|
|
-
|
|
|
- intent = self.llm.invoke(classification_prompt).content.strip().lower()
|
|
|
-
|
|
|
- if intent == "query":
|
|
|
- return self.query(user_query)
|
|
|
- elif intent == "visualize":
|
|
|
- return self.visualize(user_query)
|
|
|
- else: # both
|
|
|
- data_result = self.query(user_query)
|
|
|
- viz_input = f"基于以下数据进行可视化分析:\n{data_result}\n\n原始问题:{user_query}"
|
|
|
- return self.visualize(viz_input)
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 内部方法:获取或创建统一 Agent
|
|
|
- # ============================================================
|
|
|
- def get_together_agent(self):
|
|
|
- TOGETHER_SYSTEM_PROMPT = """你是一名全能数据分析师,精通 SQL 查询和 Python 数据分析可视化。
|
|
|
- ## 可用工具
|
|
|
- 1. **SQL 工具组**(数据库操作):
|
|
|
- - `sql_db_list_tables` — 查看数据库中有哪些表
|
|
|
- - `sql_db_schema` — 获取指定表的字段结构
|
|
|
- - `sql_db_query_checker` — 检查 SQL 语法是否正确
|
|
|
- - `sql_db_query` — 执行 SQL 查询并返回结果
|
|
|
-
|
|
|
- 2. **Python 代码执行工具**(数据分析与可视化):
|
|
|
- - `execute_python_code` — 在沙箱中执行 Python 代码
|
|
|
-
|
|
|
- ## 可用数据(Python 沙箱中已预加载)
|
|
|
- - `employees_df` — 员工表(id, name, department, salary, hire_date)
|
|
|
- - `products_df` — 产品表(id, product_name, category, price, stock)
|
|
|
- - `orders_df` — 订单表(id, employee_id, product_id, quantity, order_date)
|
|
|
- - `pd` (pandas), `plt` (matplotlib), `sns` (seaborn), `np` (numpy)
|
|
|
-
|
|
|
- ## 工作流程
|
|
|
- 1. **理解需求**:分析用户问题的核心诉求
|
|
|
- 2. **获取元数据**:如需查数据库,先用 `sql_db_list_tables` 和 `sql_db_schema` 了解表结构
|
|
|
- 3. **选择工具**:
|
|
|
- - 纯数据查询 → 用 SQL 工具
|
|
|
- - 统计分析、画图 → 用 `execute_python_code`
|
|
|
- - 复杂分析 → 先用 SQL 查数据,再用 Python 做深度分析
|
|
|
- 4. **执行与验证**:SQL 先用 `sql_db_query_checker` 检查语法;Python 代码确保正确
|
|
|
- 5. **总结输出**:用中文给出简洁的业务洞察
|
|
|
-
|
|
|
- ## 约束
|
|
|
- - 只使用数据库中实际存在的表和字段
|
|
|
- - SQL 单次查询结果限制在 50 条以内
|
|
|
- - Python 代码绑图前设置中文字体:`plt.rcParams['font.sans-serif'] = ['SimHei', 'PingFang SC', 'DejaVu Sans']`
|
|
|
- - 设置 `plt.rcParams['axes.unicode_minus'] = False`
|
|
|
- - 图表尺寸用 `plt.figure(figsize=(10, 6))`
|
|
|
- - 必须添加标题、坐标轴标签
|
|
|
- - 用 `print()` 输出关键统计量
|
|
|
- - 回答简洁专业,不要啰嗦
|
|
|
- """
|
|
|
- """延迟初始化统一 Agent,合并 SQL 工具 + Python 执行工具。"""
|
|
|
- if self.together_agent is not None:
|
|
|
- return self.together_agent
|
|
|
-
|
|
|
- # ========== 1. SQL 工具组 ==========
|
|
|
- db = SQLDatabase.from_uri(self.db_uri)
|
|
|
- print(f"数据库连接成功,可用表:{db.get_usable_table_names()}")
|
|
|
-
|
|
|
- sql_toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
|
|
- sql_tools = sql_toolkit.get_tools()
|
|
|
-
|
|
|
- print(f"SQL 工具加载完成({len(sql_tools)} 个):")
|
|
|
- for t in sql_tools:
|
|
|
- print(f" - {t.name}")
|
|
|
-
|
|
|
- # ========== 2. Python 代码执行工具 ==========
|
|
|
- self._load_dataframes()
|
|
|
-
|
|
|
- self._sandbox_globals = {
|
|
|
- "employees_df": self.employees_df,
|
|
|
- "products_df": self.products_df,
|
|
|
- "orders_df": self.orders_df,
|
|
|
- "pd": pd,
|
|
|
- "plt": plt,
|
|
|
- "sns": sns,
|
|
|
- "np": np,
|
|
|
- }
|
|
|
-
|
|
|
- @tool
|
|
|
- def execute_python_code(code: str) -> str:
|
|
|
- """
|
|
|
- 执行 Python 代码进行数据分析和可视化。
|
|
|
-
|
|
|
- 可用变量:employees_df, products_df, orders_df, pd, plt, sns, np
|
|
|
- """
|
|
|
- exec_globals = dict(self._sandbox_globals)
|
|
|
- exec_locals = {}
|
|
|
- output_buffer = StringIO()
|
|
|
-
|
|
|
- try:
|
|
|
- with redirect_stdout(output_buffer):
|
|
|
- exec(code, exec_globals, exec_locals)
|
|
|
- result = output_buffer.getvalue()
|
|
|
- if not result.strip():
|
|
|
- result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
|
|
|
- return f"执行成功:\n{result}"
|
|
|
- except Exception as e:
|
|
|
- error_detail = traceback.format_exc()
|
|
|
- return f"❌ 执行出错:{e}\n\n{error_detail}"
|
|
|
-
|
|
|
- # ========== 3. 合并所有工具 ==========
|
|
|
- all_tools = sql_tools + [execute_python_code]
|
|
|
-
|
|
|
- print(f"\n✅ 工具合并完成,共 {len(all_tools)} 个工具:")
|
|
|
- for t in all_tools:
|
|
|
- print(f" - {t.name}")
|
|
|
-
|
|
|
- # ========== 4. 创建统一 Agent ==========
|
|
|
- self.together_agent = create_agent(
|
|
|
- model=self.llm,
|
|
|
- tools=all_tools,
|
|
|
- system_prompt=TOGETHER_SYSTEM_PROMPT,
|
|
|
- )
|
|
|
- print("\n🤖 统一分析 Agent 创建完成!")
|
|
|
- return self.together_agent
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 统⼀ Agent 模式(作业)
|
|
|
- # ============================================================
|
|
|
- def together_analyze(self, user_query: str) -> str:
|
|
|
- """
|
|
|
- 统一分析入口:LLM 自主决定调用 SQL 工具或 Python 工具。
|
|
|
-
|
|
|
- Args:
|
|
|
- user_query: 用户的自然语言问题
|
|
|
-
|
|
|
- Returns:
|
|
|
- Agent 的分析结果
|
|
|
- """
|
|
|
- agent = self.get_together_agent()
|
|
|
- result = agent.invoke({"messages": [{"role": "user", "content": user_query}]})
|
|
|
- return result["messages"][-1].content
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-# ============================================================
|
|
|
-# 使用示例
|
|
|
-# ============================================================
|
|
|
-if __name__ == "__main__":
|
|
|
- # 1. 创建系统实例
|
|
|
- system = DataAnalysisSystem()
|
|
|
-
|
|
|
- # 2. 测试 LLM 连接
|
|
|
- print("\n--- LLM 测试 ---")
|
|
|
- print(system.test_llm("你好,请用一句话介绍自己"))
|
|
|
-
|
|
|
- # 3. 初始化数据库(只需执行一次)
|
|
|
- # print("\n--- 初始化数据库 ---")
|
|
|
- # system.init_database()
|
|
|
-
|
|
|
- # 4. 运行分析
|
|
|
- # print("\n--- SQL 查询 ---")
|
|
|
- # result = system.query("技术部的平均工资是多少?")
|
|
|
- # print(result)
|
|
|
- #
|
|
|
- # print("\n--- 可视化分析 ---")
|
|
|
- # result = system.visualize("统计各部门人数并画柱状图")
|
|
|
- # print(result)
|
|
|
-
|
|
|
- # print("\n--- 自动路由 ---")
|
|
|
- # result = system.analyze("销售部谁的总销售额最高?")
|
|
|
- # print(result)
|
|
|
-
|
|
|
- print("\n--- 统一分析:LLM 自动选择工具 ---")
|
|
|
- result = system.together_analyze("统计各部门人数并画柱状图?")
|
|
|
- print(result)
|