import json
import sqlite3
from contextlib import closing
from dataclasses import asdict, dataclass
from typing import List, Optional
@dataclass
class Message:
    """A single chat message, optionally flagged as an anchor."""

    role: str  # "user" or "assistant"
    content: str
    is_anchor: bool = False  # whether this message carries anchor (key) information

    def to_dict(self) -> dict:
        """Return the message as a plain dict of its fields."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "Message":
        """Build a message from a dict shaped like the output of ``to_dict``.

        A classmethod (rather than a staticmethod hard-coding ``Message``) so
        that subclasses get instances of their own type.
        """
        return cls(**data)
class ConversationManager:
    """Hold one chat session's messages, anchors and summary, persisted in SQLite.

    State is mirrored to two tables:
    ``sessions`` has one row per session with the summary and the
    JSON-encoded anchor list; ``messages`` has one row per message, kept in
    insertion order via its AUTOINCREMENT id.
    """

    # Substrings that mark a message as an anchor (key fact/constraint).
    _ANCHOR_PATTERNS = ("订单号", "手机号", "必须", "预算")

    def __init__(self, session_id: str, db_path="chat.db", max_recent=5):
        """Create the schema if needed and load any stored state for the session.

        Args:
            session_id: Key identifying the session in both tables.
            db_path: Path passed to ``sqlite3.connect``.
            max_recent: Number of most recent messages rendered verbatim by
                ``build_context``; ``add_message`` triggers compression once
                more than ``2 * max_recent`` messages are held.
        """
        self.session_id = session_id
        self.db_path = db_path
        self.messages: List[Message] = []
        self.anchors: List[str] = []
        self.summary: str = ""
        self.max_recent = max_recent
        self._init_db()
        self.load_history()

    def _connect(self):
        """Return a context manager yielding a connection that is closed on exit.

        Using a ``sqlite3.Connection`` directly in ``with`` only commits or
        rolls back the transaction; it does not close the connection.
        """
        return closing(sqlite3.connect(self.db_path))

    def _init_db(self):
        """Create the ``sessions`` and ``messages`` tables if they do not exist."""
        with self._connect() as conn:
            with conn:  # commit on success, roll back on error
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS sessions (
                        session_id TEXT PRIMARY KEY,
                        summary TEXT,
                        anchors TEXT,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS messages (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        session_id TEXT,
                        role TEXT,
                        content TEXT,
                        is_anchor BOOLEAN,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        FOREIGN KEY(session_id) REFERENCES sessions(session_id)
                    )
                """)
                # Both load and delete filter by session_id; load orders by id.
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_messages_session "
                    "ON messages(session_id, id)"
                )

    def load_history(self):
        """Load this session's summary, anchors and messages from the database.

        Summary and anchors are only overwritten when a ``sessions`` row
        exists; ``self.messages`` is always replaced by the stored rows.
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT summary, anchors FROM sessions WHERE session_id = ?",
                (self.session_id,),
            ).fetchone()
            if row:
                self.summary = row[0] or ""
                self.anchors = json.loads(row[1]) if row[1] else []
            # Order by id, not created_at: CURRENT_TIMESTAMP has one-second
            # resolution and save_snapshot inserts all rows at once, so
            # timestamps tie and cannot preserve message order.
            # NOTE: production code may only need the most recent N rows,
            # since older ones are folded into the summary.
            rows = conn.execute(
                """
                SELECT role, content, is_anchor FROM messages
                WHERE session_id = ?
                ORDER BY id ASC
                """,
                (self.session_id,),
            ).fetchall()
        # SQLite stores BOOLEAN as the integers 0/1; restore a real bool.
        self.messages = [
            Message(role, content, bool(is_anchor))
            for role, content, is_anchor in rows
        ]

    def save_snapshot(self):
        """Persist summary, anchors and the full message list in one transaction.

        Messages are rewritten wholesale (delete, then re-insert); production
        code should switch to incremental saves.
        """
        with self._connect() as conn:
            with conn:  # commit on success, roll back on error
                conn.execute(
                    """
                    INSERT OR REPLACE INTO sessions (session_id, summary, anchors)
                    VALUES (?, ?, ?)
                    """,
                    (
                        self.session_id,
                        self.summary,
                        json.dumps(self.anchors, ensure_ascii=False),
                    ),
                )
                conn.execute(
                    "DELETE FROM messages WHERE session_id = ?",
                    (self.session_id,),
                )
                conn.executemany(
                    """
                    INSERT INTO messages (session_id, role, content, is_anchor)
                    VALUES (?, ?, ?, ?)
                    """,
                    [
                        (self.session_id, m.role, m.content, m.is_anchor)
                        for m in self.messages
                    ],
                )

    def add_message(self, role: str, content: str):
        """Append a message, record it as an anchor if it matches, and persist.

        State is saved immediately; when more than ``2 * max_recent`` messages
        are held, history is compressed and the result is saved again.
        """
        is_anchor = self._detect_anchor(content)
        self.messages.append(Message(role, content, is_anchor))
        if is_anchor:
            self.anchors.append(content)
        # Save right away (an async/batched save would also work).
        self.save_snapshot()
        if len(self.messages) > self.max_recent * 2:
            self._compress_history()
            self.save_snapshot()  # persist the post-compression state

    def _detect_anchor(self, content: str) -> bool:
        """Return True if ``content`` contains any anchor keyword."""
        return any(p in content for p in self._ANCHOR_PATTERNS)

    def _compress_history(self):
        """Placeholder for history compression (logic elided in this excerpt).

        Messages it removes from ``self.messages`` are removed from the
        database by the ``save_snapshot`` call that follows in ``add_message``.
        """
        # TODO: compression logic not included in this excerpt.

    def build_context(self) -> str:
        """Render anchors, summary and the last ``max_recent`` messages as text.

        Sections are separated by a blank line. The anchor and summary
        sections are omitted when empty; the recent-dialogue section is always
        present (possibly with no lines).
        """
        parts = []
        if self.anchors:
            # One anchor per line rather than the Python list repr.
            anchor_lines = "\n".join(f"- {a}" for a in self.anchors)
            parts.append(f"【关键约束】\n{anchor_lines}")
        if self.summary:
            parts.append(f"【历史摘要】\n{self.summary}")
        # messages[-0:] is the whole list, so max_recent <= 0 must be special-cased.
        recent_msgs = self.messages[-self.max_recent:] if self.max_recent > 0 else []
        recent = "\n".join(f"{m.role}: {m.content}" for m in recent_msgs)
        parts.append(f"【最近对话】\n{recent}")
        return "\n\n".join(parts)