Jelajahi Sumber

first commit

QIN 9 jam lalu
melakukan
5a3d0bf7e8
8 mengubah file dengan 494 tambahan dan 0 penghapusan
  1. 11 0
      .gitignore
  2. 1 0
      .python-version
  3. 0 0
      README.md
  4. 102 0
      init.py
  5. 6 0
      main.py
  6. 7 0
      pyproject.toml
  7. 359 0
      sql_agent.ipynb
  8. 8 0
      uv.lock

+ 11 - 0
.gitignore

@@ -0,0 +1,11 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# Virtual environments
+.venv
+.env

+ 1 - 0
.python-version

@@ -0,0 +1 @@
+3.11

+ 0 - 0
README.md


+ 102 - 0
init.py

@@ -0,0 +1,102 @@
+import os
+import mysql.connector
+from dotenv import load_dotenv 
+
+load_dotenv()  # By default loads a .env file from the current directory or a parent directory
+# ============================================================
+# Step 1: open the database connection
+# ============================================================
+# Connection settings are read from environment variables; the second argument of each getenv is the fallback
+conn = mysql.connector.connect(
+    host=os.getenv("DB_HOST", "127.0.0.1"),
+    port=int(os.getenv("DB_PORT", 3307)),
+    user=os.getenv("DB_USER", "root"),
+    password=os.getenv("DB_PASSWORD", ""),
+    database=os.getenv("DB_NAME", "analytics_demo"),
+)
+cursor = conn.cursor()
+
+# ============================================================
+# Step 2: create the table schemas
+# ============================================================
+
+# employees table: stores internal staff information
+# Intended for analysis of headcount per department, salary structure, hiring trends, etc.
+cursor.execute("""
+CREATE TABLE IF NOT EXISTS employees (
+    id          INT PRIMARY KEY AUTO_INCREMENT,  -- 自增主键
+    name        VARCHAR(50)  NOT NULL,           -- 姓名
+    department  VARCHAR(50)  NOT NULL,           -- 所属部门
+    salary      DECIMAL(10,2) NOT NULL,          -- 月薪(精确到分)
+    hire_date   DATE                             -- 入职日期
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+""")
+
+# products table: stores the products on sale
+# Intended for analysis of sales per category, price distribution, low-stock alerts, etc.
+cursor.execute("""
+CREATE TABLE IF NOT EXISTS products (
+    id           INT PRIMARY KEY AUTO_INCREMENT,  -- 自增主键
+    product_name VARCHAR(100) NOT NULL,           -- 商品名称
+    category     VARCHAR(50)  NOT NULL,           -- 商品分类
+    price        DECIMAL(10,2) NOT NULL,          -- 单价
+    stock        INT DEFAULT 0                    -- 当前库存量
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+""")
+
+# orders table: records each transaction; references employees(id) and products(id) via foreign keys
+# Intended for analysis of sales trends, employee performance, product popularity, etc.
+cursor.execute("""
+CREATE TABLE IF NOT EXISTS orders (
+    id           INT PRIMARY KEY AUTO_INCREMENT,  -- 自增主键
+    employee_id  INT NOT NULL,                    -- 下单员工(外键)
+    product_id   INT NOT NULL,                    -- 购买商品(外键)
+    quantity     INT NOT NULL,                    -- 购买数量
+    order_date   DATE NOT NULL,                   -- 下单日期
+    FOREIGN KEY (employee_id) REFERENCES employees(id),   -- 关联员工表
+    FOREIGN KEY (product_id)  REFERENCES products(id)     -- 关联产品表
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+""")
+
+# ============================================================
+# Step 3: insert the demo data
+# ============================================================
+
+# Employee rows (id, name, department, salary, hire_date): 5 employees across 3 departments
+employees_data = [
+    (1, "张三", "技术部",   20000.00, "2023-01-15"),
+    (2, "李四", "销售部",   11000.00, "2023-02-20"),
+    (3, "王五", "技术部",   16000.00, "2022-11-10"),
+    (4, "赵六", "人力资源", 5000.00, "2023-03-01"),
+    (5, "钱七", "销售部",   17000.00, "2022-12-05"),
+]
+
+# Product rows (id, product_name, category, price, stock): 4 products covering 2 categories
+products_data = [
+    (1, "笔记本电脑", "电子产品", 6999.00, 500),
+    (2, "机械键盘",   "电子产品", 399.00,  1000),
+    (3, "办公椅",     "办公用品", 499.00,  300),
+    (4, "显示器",     "电子产品", 1200.00, 400),
+]
+
+# Order rows (id, employee_id, product_id, quantity, order_date): 5 orders
+orders_data = [
+    (1, 1, 1, 2, "2024-01-15"),  # employee 1 (Zhang San) bought 2 laptops
+    (2, 2, 2, 15, "2024-01-16"),  # employee 2 (Li Si) bought 15 keyboards
+    (3, 3, 1, 10, "2024-01-17"),  # employee 3 (Wang Wu) bought 10 laptops
+    (4, 5, 3, 6, "2024-01-18"),  # employee 5 (Qian Qi) bought 6 office chairs
+    (5, 2, 4, 5, "2024-01-19"),  # employee 2 (Li Si) bought 5 monitors
+]
+
+# Bulk-insert with executemany; parents (employees, products) go first because orders has foreign keys to them
+cursor.executemany("INSERT IGNORE INTO employees VALUES (%s,%s,%s,%s,%s)", employees_data)
+cursor.executemany("INSERT IGNORE INTO products  VALUES (%s,%s,%s,%s,%s)", products_data)
+cursor.executemany("INSERT IGNORE INTO orders     VALUES (%s,%s,%s,%s,%s)", orders_data)
+
+conn.commit()
+conn.close()
+
+print("✅ 数据库初始化完成")
+print("   - employees 表:5 条员工记录")
+print("   - products  表:4 条产品记录")
+print("   - orders    表:5 条订单记录")

+ 6 - 0
main.py

@@ -0,0 +1,6 @@
+def main():  # Entry point: prints a fixed greeting to stdout
+    print("Hello from analytics-demo!")
+
+
+if __name__ == "__main__":  # Call main() only when run as a script, not when imported
+    main()

+ 7 - 0
pyproject.toml

@@ -0,0 +1,7 @@
+[project]
+name = "analytics-demo"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = []

+ 359 - 0
sql_agent.ipynb

@@ -0,0 +1,359 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "2671054f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "C:\\Users\\34450\\AppData\\Local\\Temp\\ipykernel_16268\\3464868009.py:2: DeprecationWarning: `langchain-community` is being sunset and is no longer actively maintained. See https://github.com/langchain-ai/langchain-community/issues/674 for details and migration guidance toward standalone integration packages.\n",
+      "  from langchain_community.utilities import SQLDatabase\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "数据库连接成功\n",
+      "   可用表:['employees', 'orders', 'products']\n",
+      "\n",
+      "SQL 工具包已加载(4 个工具):\n",
+      "   - sql_db_query\n",
+      "   - sql_db_schema\n",
+      "   - sql_db_list_tables\n",
+      "   - sql_db_query_checker\n",
+      "\n",
+      "NL2SQL Agent 创建完成,可以开始提问了!\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "from langchain_community.utilities import SQLDatabase\n",
+    "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
+    "from langchain_openai import ChatOpenAI\n",
+    "from langchain.agents import create_agent  # langchain 1.3.1 的新 API\n",
+    "from dotenv import load_dotenv \n",
+    "\n",
+    "load_dotenv()  # 默认加载当前目录或父目录中的 .env 文件\n",
+    "# ============================================================\n",
+    "# 第一步:连接数据库\n",
+    "# ============================================================\n",
+    "# SQLDatabase.from_uri 接受标准的数据库连接 URI\n",
+    "# LangChain 内部会用 SQLAlchemy 管理连接池\n",
+    "db_uri = (\n",
+    "    f\"mysql+pymysql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}\"\n",
+    "    f\"@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}\"\n",
+    ")\n",
+    "db = SQLDatabase.from_uri(db_uri)\n",
+    "\n",
+    "# 验证连接:打印可用的表名\n",
+    "print(f\"数据库连接成功\")\n",
+    "print(f\"   可用表:{db.get_usable_table_names()}\")\n",
+    "\n",
+    "# ============================================================\n",
+    "# 第二步:初始化大模型\n",
+    "# ============================================================\n",
+    "llm = ChatOpenAI(\n",
+    "    model=os.getenv(\"DEEPSEEK_MODEL\"),\n",
+    "    api_key=os.getenv(\"DEEPSEEK_API_KEY\"),\n",
+    "    base_url=os.getenv(\"DEEPSEEK_BASE_URL\"),\n",
+    "    temperature=0,  # SQL 生成需要确定性输出\n",
+    ")\n",
+    "\n",
+    "# ============================================================\n",
+    "# 第三步:创建 SQL 工具包\n",
+    "# ============================================================\n",
+    "# SQLDatabaseToolkit 会自动注册 4 个工具:\n",
+    "#   1. sql_db_list_tables    — 列出数据库中所有表\n",
+    "#   2. sql_db_schema         — 获取指定表的 DDL 结构\n",
+    "#   3. sql_db_query_checker  — 检查 SQL 语法是否正确\n",
+    "#   4. sql_db_query          — 执行 SQL 并返回结果\n",
+    "toolkit = SQLDatabaseToolkit(db=db, llm=llm)\n",
+    "tools = toolkit.get_tools()\n",
+    "\n",
+    "print(f\"\\nSQL 工具包已加载({len(tools)} 个工具):\")\n",
+    "for t in tools:\n",
+    "    print(f\"   - {t.name}\")\n",
+    "\n",
+    "# ============================================================\n",
+    "# 第四步:定义 System Prompt(Agent 的\"岗位说明书\")\n",
+    "# ============================================================\n",
+    "SQL_AGENT_PROMPT = \"\"\"你是一名专业的 SQL 数据分析师。\n",
+    "\n",
+    "## 工作流程\n",
+    "1. 先用 sql_db_list_tables 查看数据库中有哪些表\n",
+    "2. 用 sql_db_schema 获取相关表的字段结构和类型\n",
+    "3. 生成 SQL 之前,用 sql_db_query_checker 检查语法\n",
+    "4. 确认无误后,用 sql_db_query 执行查询\n",
+    "5. 用中文总结查询结果,给出简洁的业务洞察\n",
+    "\n",
+    "## 约束\n",
+    "- 只使用数据库中实际存在的表和字段,不要凭空编造\n",
+    "- 单次查询结果限制在 50 条以内\n",
+    "- 如果查询出错,分析错误原因后重新生成 SQL\n",
+    "- 回答要简洁专业,不要啰嗦\n",
+    "\"\"\"\n",
+    "\n",
+    "# ============================================================\n",
+    "# 第五步:组装 Agent(langchain 1.3.1 一行搞定)\n",
+    "# ============================================================\n",
+    "# create_agent 返回一个可直接调用的 Runnable,无需再套 AgentExecutor\n",
+    "# 它内部自动处理工具调用、循环迭代、错误重试等逻辑\n",
+    "sql_agent = create_agent(\n",
+    "    model=llm,\n",
+    "    tools=tools,\n",
+    "    system_prompt=SQL_AGENT_PROMPT,\n",
+    ")\n",
+    "\n",
+    "print(\"\\nNL2SQL Agent 创建完成,可以开始提问了!\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "bd1fef44",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "回答:公司共有 **5 名员工**。\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 简单问题:Agent 只需一条 COUNT SQL 就能搞定\n",
+    "# create_agent 的 invoke 接口使用 messages 格式\n",
+    "response = sql_agent.invoke({\n",
+    "    \"messages\": [{\"role\": \"user\", \"content\": \"公司一共有多少名员工?\"}]\n",
+    "})\n",
+    "\n",
+    "# 最终回答在最后一条消息里\n",
+    "final_msg = response[\"messages\"][-1]\n",
+    "print(f\"回答:{final_msg.content}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "a7025030",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "回答:## 查询结果\n",
+      "\n",
+      "技术部共有 **2 名员工**,按薪资从高到低排序如下:\n",
+      "\n",
+      "| 姓名 | 薪资 |\n",
+      "|------|------|\n",
+      "| 张三 | 20,000.00 |\n",
+      "| 王五 | 16,000.00 |\n",
+      "\n",
+      "**业务洞察:**\n",
+      "- 技术部员工平均薪资为 **18,000 元**。\n",
+      "- 张三薪资最高(20,000元),比第二名王五高出 **25%**。\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Medium difficulty: single-table query on employees needing WHERE + ORDER BY (no JOIN or GROUP BY)\n",
+    "response = sql_agent.invoke({\n",
+    "    \"messages\": [{\"role\": \"user\", \"content\": \"列出技术部所有员工的姓名和薪资,按薪资从高到低排序\"}]\n",
+    "})\n",
+    "final_msg = response[\"messages\"][-1]\n",
+    "print(f\"回答:{final_msg.content}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "cb41a941",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "回答:查询结果如下:\n",
+      "\n",
+      "---\n",
+      "\n",
+      "### 📊 销售部订单统计\n",
+      "\n",
+      "| 指标 | 数据 |\n",
+      "|------|------|\n",
+      "| **订单总数** | **3 笔** |\n",
+      "| **订单总金额** | **¥14,979.00** |\n",
+      "\n",
+      "### 👥 员工明细\n",
+      "\n",
+      "| 员工 | 订单数 | 金额 |\n",
+      "|-----|:-----:|:----:|\n",
+      "| **李四** | 2 笔 | ¥11,985.00 |\n",
+      "| **钱七** | 1 笔 | ¥2,994.00 |\n",
+      "\n",
+      "**业务洞察:**\n",
+      "- 销售部共 **2 名员工**,累计下单 **3 笔**,总金额 **14,979 元**。\n",
+      "- 李四贡献了总金额的 **80%**(11,985 元),是销售部的核心业务人员。\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 高难度:需要 JOIN 三张表 + 聚合计算\n",
+    "response = sql_agent.invoke({\n",
+    "    \"messages\": [{\"role\": \"user\", \"content\": \"销售部的员工总共下了多少订单?订单总金额是多少?\"}]\n",
+    "})\n",
+    "final_msg = response[\"messages\"][-1]\n",
+    "print(f\"回答:{final_msg.content}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "5ce44c51",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "最终回答:每个部门各有多少人?\n",
+      "\n",
+      "步骤 2 [AIMessage]:\n",
+      "  调用工具:sql_db_list_tables\n",
+      "  参数:{}\n",
+      "  工具返回:employees, orders, products\n",
+      "\n",
+      "步骤 4 [AIMessage]:\n",
+      "  调用工具:sql_db_schema\n",
+      "  参数:{'table_names': 'employees'}\n",
+      "  工具返回:\n",
+      "CREATE TABLE employees (\n",
+      "\tid INTEGER NOT NULL AUTO_INCREMENT, \n",
+      "\tname VARCHAR(50) NOT NULL, \n",
+      "\tdepartment VARCHAR(50) NOT NULL, \n",
+      "\tsalary DECIMAL(10, 2) NOT NULL, \n",
+      "\thire_date DATE, \n",
+      "\tPRIMARY KEY (id)\n",
+      ")E\n",
+      "\n",
+      "步骤 6 [AIMessage]:\n",
+      "  调用工具:sql_db_query_checker\n",
+      "  参数:{'query': 'SELECT department, COUNT(*) AS 人数 FROM employees GROUP BY department ORDER BY 人数 DESC'}\n",
+      "  工具返回:SELECT department, COUNT(*) AS 人数 FROM employees GROUP BY department ORDER BY 人数 DESC\n",
+      "\n",
+      "步骤 8 [AIMessage]:\n",
+      "  调用工具:sql_db_query\n",
+      "  参数:{'query': 'SELECT department, COUNT(*) AS 人数 FROM employees GROUP BY department ORDER BY 人数 DESC'}\n",
+      "  工具返回:[('技术部', 2), ('销售部', 2), ('人力资源', 1)]\n",
+      "\n",
+      "最终回答:查询完成,以下是 **各部门人数统计**:\n",
+      "\n",
+      "| 部门 | 人数 |\n",
+      "|------|:----:|\n",
+      "| 🛠️ 技术部 | 2 人 |\n",
+      "| 📊 销售部 | 2 人 |\n",
+      "| 👥 人力资源 | 1 人 |\n",
+      "\n",
+      "**业务洞察:** 目前公司共有 5 名员工,技术部和销售部各占 2 人,人力资源部 1 人。技术部和销售部人员配置相对均衡。\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 遍历所有消息,还原 Agent 的完整推理链\n",
+    "response = sql_agent.invoke({\n",
+    "    \"messages\": [{\"role\": \"user\", \"content\": \"每个部门各有多少人?\"}]\n",
+    "})\n",
+    "\n",
+    "for i, msg in enumerate(response[\"messages\"]):\n",
+    "    msg_type = msg.__class__.__name__\n",
+    "\n",
+    "    # AIMessage 中如果有 tool_calls,说明 Agent 调用了工具\n",
+    "    if hasattr(msg, \"tool_calls\") and msg.tool_calls:\n",
+    "        print(f\"\\n步骤 {i+1} [{msg_type}]:\")\n",
+    "        for tc in msg.tool_calls:\n",
+    "            print(f\"  调用工具:{tc['name']}\")\n",
+    "            print(f\"  参数:{tc.get('args', {})}\")\n",
+    "\n",
+    "    # ToolMessage 是工具返回的结果\n",
+    "    elif msg_type == \"ToolMessage\":\n",
+    "        print(f\"  工具返回:{msg.content[:200]}\")\n",
+    "\n",
+    "    # Any other message with text content: the final AIMessage answer, but also the initial user message\n",
+    "    elif msg.content:\n",
+    "        print(f\"\\n最终回答:{msg.content}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "019a2781",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt\n",
+    "import seaborn as sns\n",
+    "import numpy as np\n",
+    "\n",
+    "# ============================================================\n",
+    "# 配置 matplotlib 中文显示\n",
+    "# ============================================================\n",
+    "# Windows 用 SimHei(黑体),macOS 用 PingFang SC\n",
+    "# 如果还是乱码,试试安装 fonts-noto-cjk 并清除缓存\n",
+    "plt.rcParams[\"font.sans-serif\"] = [\"SimHei\", \"PingFang SC\", \"DejaVu Sans\"]\n",
+    "plt.rcParams[\"axes.unicode_minus\"] = False  # 解决负号显示为方块的问题\n",
+    "\n",
+    "# ============================================================\n",
+    "# 从数据库加载数据到 Pandas DataFrame\n",
+    "# ============================================================\n",
+    "# Create a SQLAlchemy engine from the same db_uri for pandas (a new engine, separate from the one inside `db`)\n",
+    "from sqlalchemy import create_engine\n",
+    "\n",
+    "engine = create_engine(db_uri)\n",
+    "\n",
+    "employees_df = pd.read_sql(\"SELECT * FROM employees\", engine)\n",
+    "products_df  = pd.read_sql(\"SELECT * FROM products\", engine)\n",
+    "orders_df    = pd.read_sql(\"SELECT * FROM orders\", engine)\n",
+    "\n",
+    "print(\"✅ 数据加载完成\")\n",
+    "print(f\"   员工表:{len(employees_df)} 行 × {len(employees_df.columns)} 列\")\n",
+    "print(f\"   产品表:{len(products_df)} 行 × {len(products_df.columns)} 列\")\n",
+    "print(f\"   订单表:{len(orders_df)} 行 × {len(orders_df.columns)} 列\")\n",
+    "print(f\"\\n📋 员工表示例:\")\n",
+    "print(employees_df.head(3).to_string(index=False))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "analytics_demo",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

+ 8 - 0
uv.lock

@@ -0,0 +1,8 @@
+version = 1
+revision = 3
+requires-python = ">=3.11"
+
+[[package]]
+name = "analytics-demo"
+version = "0.1.0"
+source = { virtual = "." }