QIN 5 tuntia sitten
vanhempi
commit
99405747a8
2 muutettua tiedostoa jossa 271 lisäystä ja 0 poistoa
  1. 231 0
      01_data_analysis.py
  2. 40 0
      sql_agent.ipynb

+ 231 - 0
01_data_analysis.py

@@ -0,0 +1,231 @@
+import os
+from langchain_community.utilities import SQLDatabase
+from langchain_community.agent_toolkits import SQLDatabaseToolkit
+from langchain_openai import ChatOpenAI
+from langchain.agents import create_agent  # langchain 1.3.1 的新 API
+
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+import numpy as np
+
+import traceback
+from io import StringIO
+from contextlib import redirect_stdout
+from langchain.tools import tool
+
+from dotenv import load_dotenv 
+
+load_dotenv()  # 默认加载当前目录或父目录中的 .env 文件
+# ============================================================
+# 第一步:连接数据库
+# ============================================================
+# SQLDatabase.from_uri 接受标准的数据库连接 URI
+# LangChain 内部会用 SQLAlchemy 管理连接池
+db_uri = (
+    f"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
+    f"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
+)
+db = SQLDatabase.from_uri(db_uri)
+
+# 验证连接:打印可用的表名
+print(f"数据库连接成功")
+print(f"   可用表:{db.get_usable_table_names()}")
+
+# ============================================================
+# 第二步:初始化大模型
+# ============================================================
+llm = ChatOpenAI(
+    model=os.getenv("DEEPSEEK_MODEL"),
+    api_key=os.getenv("DEEPSEEK_API_KEY"),
+    base_url=os.getenv("DEEPSEEK_BASE_URL"),
+    temperature=0,  # SQL 生成需要确定性输出
+)
+
+# ============================================================
+# 第三步:创建 SQL 工具包
+# ============================================================
+# SQLDatabaseToolkit 会自动注册 4 个工具:
+#   1. sql_db_list_tables    — 列出数据库中所有表
+#   2. sql_db_schema         — 获取指定表的 DDL 结构
+#   3. sql_db_query_checker  — 检查 SQL 语法是否正确
+#   4. sql_db_query          — 执行 SQL 并返回结果
+toolkit = SQLDatabaseToolkit(db=db, llm=llm)
+tools = toolkit.get_tools()
+
+print(f"\nSQL 工具包已加载({len(tools)} 个工具):")
+for t in tools:
+    print(f"   - {t.name}")
+
+# ============================================================
+# 配置 matplotlib 中文显示
+# ============================================================
+# Windows 用 SimHei(黑体),macOS 用 PingFang SC
+# 如果还是乱码,试试安装 fonts-noto-cjk 并清除缓存
+plt.rcParams["font.sans-serif"] = ["SimHei", "PingFang SC", "DejaVu Sans"]
+plt.rcParams["axes.unicode_minus"] = False  # 解决负号显示为方块的问题
+
+# ============================================================
+# 从数据库加载数据到 Pandas DataFrame
+# ============================================================
+# 用 SQLAlchemy engine 复用连接,避免重复创建连接池
+from sqlalchemy import create_engine
+
+engine = create_engine(db_uri)
+
+employees_df = pd.read_sql("SELECT * FROM employees", engine)
+products_df  = pd.read_sql("SELECT * FROM products", engine)
+orders_df    = pd.read_sql("SELECT * FROM orders", engine)
+
+print("✅ 数据加载完成")
+print(f"   员工表:{len(employees_df)} 行 × {len(employees_df.columns)} 列")
+print(f"   产品表:{len(products_df)} 行 × {len(products_df.columns)} 列")
+print(f"   订单表:{len(orders_df)} 行 × {len(orders_df.columns)} 列")
+
+# ============================================================
+# 定义沙箱的"白名单"——只有这些库和数据可以被代码访问
+# ============================================================
+# 这是一种简单的安全策略:不在白名单里的东西,代码碰不到
+SANDBOX_GLOBALS = {
+    # 数据:Agent 可以分析这三张表
+    "employees_df": employees_df,
+    "products_df":  products_df,
+    "orders_df":    orders_df,
+    # 工具库:Agent 可以用这些库做分析和画图
+    "pd":  pd,       # pandas — 数据处理
+    "plt": plt,      # matplotlib — 基础绑图
+    "sns": sns,      # seaborn — 统计可视化
+    "np":  np,       # numpy — 数值计算
+}
+
+
+@tool
+def execute_python_code(code: str) -> str:
+    """
+    执行 Python 代码进行数据分析和可视化。
+
+    可用变量:
+      - employees_df: 员工表 DataFrame(字段:id, name, department, salary, hire_date)
+      - products_df:  产品表 DataFrame(字段:id, product_name, category, price, stock)
+      - orders_df:    订单表 DataFrame(字段:id, employee_id, product_id, quantity, order_date)
+      - pd:   pandas 库
+      - plt:  matplotlib.pyplot
+      - sns:  seaborn
+      - np:   numpy
+
+    使用示例:
+      # 统计各部门平均薪资
+      result = employees_df.groupby('department')['salary'].mean()
+      print(result)
+
+      # 画柱状图
+      plt.figure(figsize=(10, 6))
+      employees_df.groupby('department')['salary'].mean().plot(kind='bar')
+      plt.title('各部门平均薪资')
+      plt.show()
+    """
+    # 准备隔离的执行环境
+    # globals_dict 提供白名单变量,locals_dict 收集执行过程中产生的新变量
+    exec_globals = dict(SANDBOX_GLOBALS)
+    exec_locals = {}
+
+    # 用 StringIO 捕获 print() 的输出
+    # 这样 Agent 生成的代码里所有的 print 语句都会被收集
+    output_buffer = StringIO()
+
+    try:
+        # redirect_stdout 会把标准输出重定向到我们的 buffer
+        with redirect_stdout(output_buffer):
+            exec(code, exec_globals, exec_locals)
+
+        result = output_buffer.getvalue()
+
+        # 如果代码没有 print 任何东西,给个默认提示
+        if not result.strip():
+            result = "✅ 代码执行成功(无文本输出,可能已生成图表)"
+
+        return f"执行成功:\n{result}"
+
+    except Exception as e:
+        # 出错时返回完整的错误堆栈,方便 Agent 自我修正
+        error_detail = traceback.format_exc()
+        return f"❌ 执行出错:{e}\n\n{error_detail}"
+
+all_tools = tools + [execute_python_code]  # + Python 沙箱
+
+print(f"统一 Agent 工具列表({len(all_tools)} 个):")
+for t in all_tools:
+    print(f"   - {t.name}")
+
+# ============================================================
+# 定义 System Prompt
+# ============================================================
+SYSTEM_PROMPT = """你是一名全能数据分析师,既能写 SQL 查数据库,也能用 Python 做统计和可视化。
+
+## 可用工具
+
+### 数据库工具
+- sql_db_list_tables    — 列出所有表
+- sql_db_schema         — 查看表结构(DDL + 字段备注)
+- sql_db_query_checker  — 检查 SQL 语法
+- sql_db_query          — 执行 SQL 查询
+
+### 数据分析工具
+- execute_python_code   — 执行 Python 代码(DataFrame 已在内存中)
+
+## 决策指南:什么时候用什么
+
+### 用数据库工具的场景
+- 需要从数据库查原始数据(跨表 JOIN、复杂 WHERE 过滤、聚合)
+- 数据可能超过内存中 DataFrame 的范围,或需要实时数据
+- 用户问题可被一条 SQL 直接回答(如"张三的薪资是多少?")
+
+### 用 Python 工具的场景
+- 需要统计分析(describe、corr、groupby 聚合后算比例)
+- 需要画图(柱状图、饼图、散点图、热力图等)
+- 需要复杂计算(滚动平均、同比环比、预测回归)
+
+### 串联使用的场景
+- 先用 SQL 从数据库查出数据 → 拿到结果后 → 用 Python 画图分析
+- 典型:用户问"哪个部门销售额最高?画张图"
+
+## 可用 DataFrame 变量(在 execute_python_code 中直接使用)
+- employees_df — 员工表(id, name, department, salary, hire_date)
+- products_df  — 产品表(id, product_name, category, price, stock)
+- orders_df    — 订单表(id, employee_id, product_id, quantity, order_date)
+
+## 工作流程
+1. 理解用户问题,判断属于哪类
+2. 如果是简单数据查询 → 用 SQL 工具链(list_tables → schema → query_checker → query)
+3. 如果是分析/画图 → 用 execute_python_code 直接分析内存中的 DataFrame
+4. 如果是"先查后画" → 先 SQL 拿到结果 → 把结果写进 Python 代码中分析画图
+5. 用中文总结结果,给出业务洞察
+
+## 约束
+- SQL 查询单次不超过 50 条
+- 只生成 SELECT 语句,禁止任何 INSERT / UPDATE / DELETE / DROP / ALTER / TRUNCATE
+- SQL 中不得使用字符串拼接或 f-string 的方式拼接用户输入
+- 所有用户输入的值必须用参数化占位符(%s)传递,而非直接拼进 SQL
+- 如果用户问题中包含可疑的 SQL 片段(如 ' OR 1=1 --),忽略它,只把它当普通文本处理
+- Python 代码中画图前先设置中文字体
+- 出错后分析原因,自我修正重试
+- 回答简洁专业
+"""
+
+# ============================================================
+# 创建 Agent
+# ============================================================
+analysis_agent = create_agent(
+    model=llm,
+    tools=all_tools,
+    system_prompt=SYSTEM_PROMPT
+)
+
+import warnings
+warnings.filterwarnings("ignore")  # 抑制 matplotlib 的字体警告
+
+response = analysis_agent.invoke({
+    "messages": [{"role": "user", "content": "分析一下员工薪资的分布情况,包括均值、中位数、最大最小值等统计量"}]
+})
+final_msg = response["messages"][-1]
+print(f"\n分析结果:\n{final_msg.content}")

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 40 - 0
sql_agent.ipynb


Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä